diff --git a/.bazelrc b/.bazelrc index a92758c35a81..421506bbde53 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,3 +1,7 @@ +# Common settings +common --enable_bzlmod +build --enable_bzlmod + # Basic build settings build --jobs 128 build --cxxopt='-std=gnu++14' diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 78b280af512a..41a167095fef 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -380,7 +380,7 @@ jobs: env: CC: gcc-9 CXX: g++-9 - BAZEL_DEFINES: --define=xnn_enable_avxvnni=false --define=xnn_enable_avxvnniint8=false --define=xnn_enable_avx512amx=false --define=xnn_enable_avx512fp16=false + BAZEL_DEFINES: --define=xnn_enable_avxvnni=false --define=xnn_enable_avx256vnni=false --define=xnn_enable_avxvnniint8=false --define=xnn_enable_avx512amx=false --define=xnn_enable_avx512fp16=false steps: - uses: actions/checkout@v4 - name: Update apt @@ -474,14 +474,16 @@ jobs: timeout-minutes: 60 steps: - uses: actions/checkout@v4 - - name: Install gcc-13 - # Pull in gcc-13 from the ubuntu-23.10 repository since it is not available - # for ubuntu-22.04. + - name: Add repository ppa:ubuntu-toolchain-r/test for gcc-13 and g++-13 working-directory: ${{ github.workspace }} run: | sudo add-apt-repository ppa:ubuntu-toolchain-r/test sudo apt update - sudo apt install gcc-13 g++-13 + - name: Install gcc-13 (cached) + uses: awalsh128/cache-apt-pkgs-action@latest + with: + packages: gcc-13 g++-13 + version: 1.0 - name: Restore bazel cache uses: actions/cache/restore@v4 with: diff --git a/BUILD.bazel b/BUILD.bazel index 71ef0e2c1726..a5af5832e27b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -113,6 +113,7 @@ MICROKERNEL_DEFS = [ "src/f32-pavgpool/f32-pavgpool-minmax.h", "src/f32-qs8-vcvt/f32-qs8-vcvt.h", "src/f32-qu8-vcvt/f32-qu8-vcvt.h", + "src/f32-raddextexp/f32-raddextexp.h", "src/f32-vabs/f32-vabs.h", "src/f32-vbinary/f32-vadd.h", "src/f32-vbinary/f32-vaddc.h", @@ -192,7 +193,6 @@ MICROKERNEL_DEFS = [ "src/s8-ibilinear/s8-ibilinear.h", "src/s8-maxpool/s8-maxpool-minmax.h", "src/s8-vclamp/s8-vclamp.h", - "src/s32-f32-vcvt/s32-f32-vcvt.h", "src/u8-ibilinear/u8-ibilinear.h", "src/u8-maxpool/u8-maxpool-minmax.h", "src/u8-vclamp/u8-vclamp.h", @@ -993,6 +993,7 @@ xnnpack_cc_library( ":datatype", ":fp16", ":indirection", + ":internal", ":logging", ":math", ":microkernel_configs", diff --git a/CMakeLists.txt b/CMakeLists.txt index 856a47d1a665..0e234597c53b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,7 +44,12 @@ SET(CMAKE_CXX_EXTENSIONS NO) # ---[ Options. SET(XNNPACK_LIBRARY_TYPE "default" CACHE STRING "Type of library (shared, static, or default) to build") SET_PROPERTY(CACHE XNNPACK_LIBRARY_TYPE PROPERTY STRINGS default static shared) -OPTION(XNNPACK_ENABLE_ASSEMBLY "Build XNNPACK with assembly micro-kernels" ON) +IF(CMAKE_C_COMPILER_ID STREQUAL "MSVC") + # Disable assembly when using MSVC until support is added. 
+ OPTION(XNNPACK_ENABLE_ASSEMBLY "Build XNNPACK with assembly micro-kernels" OFF) +ELSE() + OPTION(XNNPACK_ENABLE_ASSEMBLY "Build XNNPACK with assembly micro-kernels" ON) +ENDIF() OPTION(XNNPACK_ENABLE_MEMOPT "Build XNNPACK with optimized memory allocation scheme" ON) OPTION(XNNPACK_ENABLE_SPARSE "Build XNNPACK with graph rewriting for sparse inference" ON) OPTION(XNNPACK_ENABLE_GEMM_M_SPECIALIZATION "Build XNNPACK with support for selecting microkernel with different MR" ON) @@ -430,7 +435,6 @@ SET(OPERATOR_SRCS src/operators/average-pooling-nhwc.c src/operators/batch-matrix-multiply-nc.c src/operators/binary-elementwise-nd.c - src/operators/channel-shuffle-nc.c src/operators/constant-pad-nd.c src/operators/convolution-nchw.c src/operators/convolution-nhwc.c @@ -523,7 +527,6 @@ SET(XNNPACK_SRCS src/configs/xx-fill-config.c src/configs/xx-pad-config.c src/configs/x8-lut-config.c - src/configs/zip-config.c src/init.c src/params.c "${PROJECT_BINARY_DIR}/build_identifier.c") @@ -660,6 +663,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$") LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_F16C_MICROKERNEL_SRCS}) LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_FMA3_MICROKERNEL_SRCS}) LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX2_MICROKERNEL_SRCS}) + IF(XNNPACK_ENABLE_ASSEMBLY AND XNNPACK_TARGET_PROCESSOR MATCHES "x86_64") + LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AMD64_ASM_MICROKERNEL_SRCS}) + ENDIF() IF(XNNPACK_ENABLE_AVX512AMX) LIST(APPEND PROD_MICROKERNEL_SRCS ${PROD_AVX512AMX_MICROKERNEL_SRCS}) ENDIF() @@ -707,6 +713,9 @@ IF(XNNPACK_TARGET_PROCESSOR MATCHES "^x86(_64)?$") LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_F16C_MICROKERNEL_SRCS}) LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_FMA3_MICROKERNEL_SRCS}) LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX2_MICROKERNEL_SRCS}) + IF(XNNPACK_ENABLE_ASSEMBLY AND XNNPACK_TARGET_PROCESSOR MATCHES "x86_64") + LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AMD64_ASM_MICROKERNEL_SRCS}) + ENDIF() IF(XNNPACK_ENABLE_AVX512AMX) LIST(APPEND NON_PROD_MICROKERNEL_SRCS ${NON_PROD_AVX512AMX_MICROKERNEL_SRCS}) ENDIF() @@ -1219,6 +1228,8 @@ IF(XNNPACK_BUILD_TESTS) # Helper libraries ADD_LIBRARY(next-prime STATIC test/next_prime.cc) + ADD_LIBRARY(runtime-flags STATIC test/runtime-flags.cc) + TARGET_LINK_LIBRARIES(runtime-flags PRIVATE GTest::gtest) ADD_LIBRARY(gemm-microkernel-tester STATIC test/gemm-microkernel-tester.cc) TARGET_INCLUDE_DIRECTORIES(gemm-microkernel-tester PRIVATE include src test) @@ -1279,6 +1290,7 @@ IF(XNNPACK_BUILD_TESTS) microparams-init next-prime pthreadpool + runtime-flags XNNPACK) ADD_SHARDED_TEST(${TEST}-test 10) ENDFOREACH() @@ -1300,6 +1312,7 @@ IF(XNNPACK_BUILD_TESTS) GTest::gmock GTest::gtest GTest::gtest_main + runtime-flags XNNPACK) ADD_SHARDED_TEST(${TEST}-test 10) ENDFOREACH() @@ -1315,6 +1328,7 @@ IF(XNNPACK_BUILD_TESTS) GTest::gtest GTest::gtest_main datatype + runtime-flags unary-ops XNNPACK) ADD_TEST(NAME unary-elementwise-nc-test COMMAND unary-elementwise-nc-test) @@ -1335,6 +1349,7 @@ IF(XNNPACK_BUILD_TESTS) GTest::gmock GTest::gtest GTest::gtest_main + runtime-flags XNNPACK) ADD_TEST(NAME ${TEST}-test COMMAND ${TEST}-test) ENDFOREACH() @@ -1381,6 +1396,7 @@ IF(XNNPACK_BUILD_TESTS) datatype subgraph logging + runtime-flags unary-ops XNNPACK) ADD_TEST(NAME ${TEST}-test COMMAND ${TEST}-test) @@ -1399,6 +1415,7 @@ IF(XNNPACK_BUILD_TESTS) GTest::gmock GTest::gtest GTest::gtest_main + runtime-flags subgraph XNNPACK) ADD_TEST(NAME ${TEST}-test COMMAND ${TEST}-test) @@ -1464,12 +1481,10 @@ IF(XNNPACK_BUILD_TESTS) x32-packw 
x32-packx x32-unpool - x32-zip x8-lut x8-packw qs8-packw qs8-qc4w-packw - x8-zip xN-transpose xx-fill xx-pad) @@ -1557,6 +1572,7 @@ IF(XNNPACK_BUILD_TESTS) qd8-f32-qc4w-gemm-minmax qd8-f32-qc8w-igemm-minmax qp8-f32-qc4w-gemm-minmax + qp8-f32-qc8w-gemm-minmax qp8-f32-qb4w-gemm-minmax qs8-qc8w-gemm-minmax-fp32 qs8-qc8w-igemm-minmax-fp32 @@ -1681,7 +1697,6 @@ IF(XNNPACK_BUILD_TESTS) f32-f16-vcvt f32-qs8-vcvt f32-qu8-vcvt - s32-f32-vcvt qs8-f16-vcvt qs8-f32-vcvt qs8-vcvt @@ -1857,7 +1872,7 @@ IF(XNNPACK_BUILD_BENCHMARKS) TARGET_LINK_LIBRARIES(models PRIVATE XNNPACK) ADD_EXECUTABLE(bench-models bench/models/benchmark.cc) - TARGET_INCLUDE_DIRECTORIES(bench-models PRIVATE bench) + TARGET_INCLUDE_DIRECTORIES(bench-models PRIVATE bench ${GOOGLEBENCHMARK_SOURCE_DIR}) TARGET_LINK_LIBRARIES(bench-models PRIVATE bench-utils benchmark::benchmark @@ -1867,7 +1882,6 @@ IF(XNNPACK_BUILD_BENCHMARKS) # ---[ Build operator-level microbenchmarks SET(LIBRARY_OPERATOR_BENCHMARKS average-pooling - channel-shuffle convolution deconvolution max-pooling @@ -1936,6 +1950,7 @@ IF(XNNPACK_BUILD_BENCHMARKS) qd8-f32-qc4w-gemm qd8-f32-qc8w-gemm qp8-f32-qc4w-gemm + qp8-f32-qc8w-gemm qp8-f32-qb4w-gemm qs8-dwconv qs8-gemm diff --git a/WORKSPACE b/MODULE.bazel similarity index 53% rename from WORKSPACE rename to MODULE.bazel index f4c8ba82ca8e..a411e210c635 100644 --- a/WORKSPACE +++ b/MODULE.bazel @@ -1,49 +1,35 @@ -workspace(name = "xnnpack") - -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +## MODULE.bazel +module( + name = "xnnpack", +) # Bazel rule definitions -http_archive( - name = "rules_cc", - sha256 = "3868eab488bd5be37a6acedbd222a196bea14408a2857916f33cce7b4780897d", - strip_prefix = "rules_cc-5e848c1434d3458018734238dbc4781f43992ea5", - urls = [ - "https://github.com/bazelbuild/rules_cc/archive/5e848c1434d3458018734238dbc4781f43992ea5.zip", - ], -) +bazel_dep(name = "rules_cc", version = "0.1.0") +bazel_dep(name = "rules_python", version = "1.0.0") -# Bazel Python rule definitions. -http_archive( - name = "rules_python", - sha256 = "4912ced70dc1a2a8e4b86cec233b192ca053e82bc72d877b98e126156e8f228d", - strip_prefix = "rules_python-0.32.2", - urls = [ - "https://github.com/bazelbuild/rules_python/releases/download/0.32.2/rules_python-0.32.2.tar.gz", - ], +pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") +pip.parse( + hub_name = "pip", + python_version = "3.11", + requirements_lock = "//:requirements_lock.txt", ) - -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() +use_repo(pip, "pip") # Bazel Skylib. -http_archive( - name = "bazel_skylib", - sha256 = "f7be3474d42aae265405a592bb7da8e171919d74c16f082a5457840f06054728", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.2.1/bazel-skylib-1.2.1.tar.gz", - ], -) +bazel_dep(name = "bazel_skylib", version = "1.7.1") # Bazel Platforms -http_archive( - name = "platforms", - sha256 = "5308fc1d8865406a49427ba24a9ab53087f17f5266a7aabbfc28823f3916e1ca", - urls = ["https://github.com/bazelbuild/platforms/releases/download/0.0.6/platforms-0.0.6.tar.gz"], -) +bazel_dep(name = "platforms", version = "0.0.10") + +# TODO: some (most? all?) 
of the http_archive() calls below could become bazel_dep() calls, +# but it would require verifying that the semver provided by the Bazel registry matches the hash +# that we expect in CMake; it's not clear that it is a big win to do so given the modest +# complexity of our deps, so I'm leaving it like this for now to ensure that the Bazel and CMake +# builds are using identical dependencies. + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -# LINT.IfChange +# LINT.IfChange(googletest) # Google Test framework, used by most unit-tests. http_archive( name = "com_google_googletest", @@ -53,7 +39,7 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadGoogleTest.cmake) -# LINT.IfChange +# LINT.IfChange(benchmark) # Google Benchmark library, used in micro-benchmarks. http_archive( name = "com_google_benchmark", @@ -63,7 +49,7 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadGoogleBenchmark.cmake) -# LINT.IfChange +# LINT.IfChange(FXdiv) # FXdiv library, used for repeated integer division by the same factor http_archive( name = "FXdiv", @@ -73,17 +59,17 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadFXdiv.cmake) -# LINT.IfChange +# LINT.IfChange(pthreadpool) # pthreadpool library, used for parallelization http_archive( name = "pthreadpool", - sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95", - strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8", - urls = ["https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"], + sha256 = "9f1baba9e97df8abc792eeaa2a8f0e0d29e507db1b4c1a8210868c889eb449b5", + strip_prefix = "pthreadpool-39df650e19d4f6382e246c29d6819b1ce6ee0b24", + urls = ["https://github.com/google/pthreadpool/archive/39df650e19d4f6382e246c29d6819b1ce6ee0b24.zip"], ) # LINT.ThenChange(cmake/DownloadPThreadPool.cmake) -# LINT.IfChange +# LINT.IfChange(cpuinfo) # cpuinfo library, used for detecting processor characteristics http_archive( name = "cpuinfo", @@ -95,14 +81,14 @@ http_archive( ) # LINT.ThenChange(cmake/DownloadCpuinfo.cmake) -# LINT.IfChange +# LINT.IfChange(kleidiai) # KleidiAI library, used for ARM microkernels. 
http_archive( name = "KleidiAI", - sha256 = "ad37707084a6d4ff41be10cbe8540c75bea057ba79d0de6c367c1bfac6ba0852", - strip_prefix = "kleidiai-40a926833857fb64786e02f97703e42b1537cb57", + sha256 = "8ba8cdb9f945941174d34d10eb4ad158ad1cbc1aef259de5ad992b0bbe85861f", + strip_prefix = "kleidiai-7e8c4baf953227fa447a2f345e5d6491a504aa56", urls = [ - "https://gitlab.arm.com/kleidi/kleidiai/-/archive/40a926833857fb64786e02f97703e42b1537cb57/kleidiai-40a926833857fb64786e02f97703e42b1537cb57.zip" + "https://gitlab.arm.com/kleidi/kleidiai/-/archive/7e8c4baf953227fa447a2f345e5d6491a504aa56/kleidiai-7e8c4baf953227fa447a2f345e5d6491a504aa56.zip", ], ) # LINT.ThenChange(cmake/DownloadKleidiAI.cmake) diff --git a/bench/BUILD.bazel b/bench/BUILD.bazel index 6c6372cefad5..93380b3a908a 100644 --- a/bench/BUILD.bazel +++ b/bench/BUILD.bazel @@ -218,7 +218,22 @@ xnnpack_benchmark( ) xnnpack_benchmark( - name = "qp8_f32_qb4w_gemm", + name = "qp8_f32_qc8w_gemm_bench", + srcs = [ + "qp8-f32-qc8w-gemm.cc", + ], + defines = xnnpack_kleidiai_defines(), + tags = xnnpack_slow_benchmark_tags(), + deps = MICROKERNEL_BENCHMARK_DEPS + [ + ":gemm_benchmark", + "//:isa_checks", + ] + xnnpack_if_kleidiai_enabled([ + "@KleidiAI//kai/ukernels/matmul", + ]), +) + +xnnpack_benchmark( + name = "qp8_f32_qb4w_gemm_bench", srcs = ["qp8-f32-qb4w-gemm.cc"], defines = xnnpack_kleidiai_defines(), tags = xnnpack_slow_benchmark_tags(), @@ -593,12 +608,6 @@ xnnpack_benchmark( ], ) -xnnpack_benchmark( - name = "channel_shuffle_bench", - srcs = ["channel-shuffle.cc"], - deps = OPERATOR_BENCHMARK_DEPS, -) - xnnpack_benchmark( name = "convolution_bench", srcs = ["convolution.cc"], diff --git a/bench/batch-matrix-multiply.cc b/bench/batch-matrix-multiply.cc index df7606c42131..38620bdb0ff2 100644 --- a/bench/batch-matrix-multiply.cc +++ b/bench/batch-matrix-multiply.cc @@ -35,6 +35,10 @@ #include "tensorflow/lite/version.h" #endif // BENCHMARK_TENSORFLOW_LITE +namespace { +static const size_t kMinIterations = 10; +} // namespace + // Pthreadpool-compatible function to wipe the cache in each thread. 
void PthreadpoolClearL2Cache(void* context, size_t id) { #if XNN_ENABLE_CPUINFO @@ -122,16 +126,18 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, return; } - for (auto _ : state) { - state.PauseTiming(); - pthreadpool_parallelize_1d(threadpool, PthreadpoolClearL2Cache, nullptr, - num_threads, 0); - state.ResumeTiming(); - - status = xnn_run_operator(op, threadpool); - if (status != xnn_status_success) { - state.SkipWithError("failed to run FP32 BatchMatrixMultiply operator"); - return; + while (state.KeepRunningBatch(kMinIterations)) { + for (int iter = 0; iter < kMinIterations; iter++) { + state.PauseTiming(); + pthreadpool_parallelize_1d(threadpool, PthreadpoolClearL2Cache, nullptr, + num_threads, 0); + state.ResumeTiming(); + + status = xnn_run_operator(op, threadpool); + if (status != xnn_status_success) { + state.SkipWithError("failed to run FP32 BatchMatrixMultiply operator"); + return; + } } } @@ -231,17 +237,19 @@ void xnnpack_batch_matrix_multiply_qd8_f32_qc8w(benchmark::State& state, return; } - for (auto _ : state) { - state.PauseTiming(); - pthreadpool_parallelize_1d(threadpool, PthreadpoolClearL2Cache, nullptr, - num_threads, 0); - state.ResumeTiming(); - - status = xnn_run_operator(op, threadpool); - if (status != xnn_status_success) { - state.SkipWithError( - "failed to run QD8_F32_QC8W BatchMatrixMultiply operator"); - return; + while (state.KeepRunningBatch(kMinIterations)) { + for (int iter = 0; iter < kMinIterations; iter++) { + state.PauseTiming(); + pthreadpool_parallelize_1d(threadpool, PthreadpoolClearL2Cache, nullptr, + num_threads, 0); + state.ResumeTiming(); + + status = xnn_run_operator(op, threadpool); + if (status != xnn_status_success) { + state.SkipWithError( + "failed to run QD8_F32_QC8W BatchMatrixMultiply operator"); + return; + } } } @@ -378,10 +386,12 @@ void tflite_batch_matrix_multiply_f32(benchmark::State& state, interpreter->typed_tensor(1) + batch_size * k * n, std::ref(f32rng)); - for (auto _ : state) { - if (interpreter->Invoke() != kTfLiteOk) { - state.SkipWithError("failed to invoke TFLite interpreter"); - return; + while (state.KeepRunningBatch(kMinIterations)) { + for (int iter = 0; iter < kMinIterations; iter++) { + if (interpreter->Invoke() != kTfLiteOk) { + state.SkipWithError("failed to invoke TFLite interpreter"); + return; + } } } diff --git a/bench/binary.cc b/bench/binary.cc index 5764f02b4fe0..6b77b1bf1162 100644 --- a/bench/binary.cc +++ b/bench/binary.cc @@ -33,7 +33,7 @@ void init_params(xnn_binary_operator op_type, xnn_datatype datatype, xnn_binary_params& params, xnn_quantization_params& input_quantization, xnn_quantization_params& output_quantization) { - switch (op_type) { + switch (datatype) { case xnn_datatype_qint8: input_quantization = {0, 1.0f / 128.0f}; output_quantization = {128, 1.0f / 128.0f}; diff --git a/bench/channel-shuffle.cc b/bench/channel-shuffle.cc deleted file mode 100644 index 0a9f820ba3a8..000000000000 --- a/bench/channel-shuffle.cc +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include -#include -#include -#include -#include -#include - -#include "xnnpack.h" - -#include -#include "utils.h" -#include "xnnpack/buffer.h" - - -static void channel_shuffle_x8(benchmark::State& state, const char* net) { - const size_t batch_size = static_cast<size_t>(state.range(0)); - const size_t groups = static_cast<size_t>(state.range(1)); - const size_t group_channels = static_cast<size_t>(state.range(2)); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - - xnnpack::Buffer<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + batch_size * groups * group_channels); - xnnpack::Buffer<uint8_t> output(batch_size * groups * group_channels); - xnnpack::fill_uniform_random_bits(input.data(), input.size(), rng); - - xnn_status status = xnn_initialize(nullptr /* allocator */); - if (status != xnn_status_success) { - state.SkipWithError("failed to initialize XNNPACK"); - return; - } - - xnn_operator_t channel_shuffle_op = nullptr; - status = xnn_create_channel_shuffle_nc_x8( - groups, group_channels, - groups * group_channels /* input stride */, - groups * group_channels /* output stride */, - 0 /* flags */, &channel_shuffle_op); - if (status != xnn_status_success || channel_shuffle_op == nullptr) { - state.SkipWithError("failed to create X8 Channel Shuffle operator"); - return; - } - - status = xnn_reshape_channel_shuffle_nc_x8( - channel_shuffle_op, - batch_size, - /*threadpool=*/nullptr); - if (status != xnn_status_success) { - state.SkipWithError("failed to reshape X8 Channel Shuffle operator"); - return; - } - - status = xnn_setup_channel_shuffle_nc_x8( - channel_shuffle_op, - input.data(), output.data()); - if (status != xnn_status_success) { - state.SkipWithError("failed to setup X8 Channel Shuffle operator"); - return; - } - - for (auto _ : state) { - status = xnn_run_operator(channel_shuffle_op, /*threadpool=*/nullptr); - if (status != xnn_status_success) { - state.SkipWithError("failed to run X8 Channel Shuffle operator"); - return; - } - } - - status = xnn_delete_operator(channel_shuffle_op); - if (status != xnn_status_success) { - state.SkipWithError("failed to delete X8 Channel Shuffle operator"); - return; - } - - const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); - if (cpu_frequency != 0) { - state.counters["cpufreq"] = cpu_frequency; - } - - const size_t elements_per_iteration = batch_size * groups * group_channels; - state.counters["elements"] = - benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); - - const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(uint8_t); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); -} - -static void channel_shuffle_x32(benchmark::State& state, const char* net) { - const size_t batch_size = static_cast<size_t>(state.range(0)); - const size_t groups = static_cast<size_t>(state.range(1)); - const size_t group_channels = static_cast<size_t>(state.range(2)); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng)); - - xnnpack::Buffer<float> input(XNN_EXTRA_BYTES / sizeof(float) + batch_size * groups * group_channels); - xnnpack::Buffer<float> output(batch_size * groups * group_channels); - std::generate(input.begin(), input.end(), std::ref(f32rng)); - - xnn_status status = xnn_initialize(nullptr /* allocator */); - if (status != xnn_status_success) { - state.SkipWithError("failed to initialize XNNPACK"); - return; - } - -
xnn_operator_t channel_shuffle_op = nullptr; - status = xnn_create_channel_shuffle_nc_x32( - groups, group_channels, - groups * group_channels /* input stride */, - groups * group_channels /* output stride */, - 0 /* flags */, &channel_shuffle_op); - if (status != xnn_status_success || channel_shuffle_op == nullptr) { - state.SkipWithError("failed to create X32 Channel Shuffle operator"); - return; - } - - status = xnn_reshape_channel_shuffle_nc_x32( - channel_shuffle_op, - batch_size, - /*threadpool=*/nullptr); - if (status != xnn_status_success) { - state.SkipWithError("failed to reshape X32 Channel Shuffle operator"); - return; - } - - status = xnn_setup_channel_shuffle_nc_x32( - channel_shuffle_op, - input.data(), output.data()); - if (status != xnn_status_success) { - state.SkipWithError("failed to setup X32 Channel Shuffle operator"); - return; - } - - for (auto _ : state) { - status = xnn_run_operator(channel_shuffle_op, /*threadpool=*/nullptr); - if (status != xnn_status_success) { - state.SkipWithError("failed to run X32 Channel Shuffle operator"); - return; - } - } - - status = xnn_delete_operator(channel_shuffle_op); - if (status != xnn_status_success) { - state.SkipWithError("failed to delete X32 Channel Shuffle operator"); - return; - } - - const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); - if (cpu_frequency != 0) { - state.counters["cpufreq"] = cpu_frequency; - } - - const size_t elements_per_iteration = batch_size * groups * group_channels; - state.counters["elements"] = - benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); - - const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); -} - -static void ShuffleNetV1G2Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 ********/ - /* H W G CG */ - b->Args({56 * 56, 2, 25}); - b->Args({28 * 28, 2, 25}); - - /******** Stage 3 ********/ - /* H W G CG */ - b->Args({28 * 28, 2, 50}); - b->Args({14 * 14, 2, 50}); - - /******** Stage 4 ********/ - /* H W G CG */ - b->Args({14 * 14, 2, 100}); - b->Args({ 7 * 7, 2, 100}); -} - -static void ShuffleNetV1G3Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 *******/ - /* H W G CG */ - b->Args({56 * 56, 3, 20}); - b->Args({28 * 28, 3, 20}); - - /******** Stage 3 *******/ - /* H W G CG */ - b->Args({28 * 28, 3, 40}); - b->Args({14 * 14, 3, 40}); - - /******** Stage 4 *******/ - /* H W G CG */ - b->Args({14 * 14, 3, 80}); - b->Args({ 7 * 7, 3, 80}); -} - -static void ShuffleNetV1G4Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 *******/ - /* H W G CG */ - b->Args({56 * 56, 4, 17}); - b->Args({28 * 28, 4, 17}); - - /******** Stage 3 *******/ - /* H W G CG */ - b->Args({28 * 28, 4, 34}); - b->Args({14 * 14, 4, 34}); - - /******** Stage 4 *******/ - /* H W G CG */ - b->Args({14 * 14, 4, 68}); - b->Args({ 7 * 7, 4, 68}); -} - -static void ShuffleNetV1G8Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 *******/ - /* H W G CG */ - b->Args({56 * 56, 8, 12}); - b->Args({28 * 28, 8, 12}); - - /******** Stage 3 *******/ - /* H W G CG */ - b->Args({28 * 28, 8, 24}); - b->Args({14 * 14, 8, 24}); - - /******** Stage 4 *******/ - /* H W G CG */ - b->Args({14 * 14, 8, 48}); 
- b->Args({ 7 * 7, 8, 48}); -} - -static void ShuffleNetV2x0_5Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 *******/ - /* H W G CG */ - b->Args({28 * 28, 2, 24}); - - /******** Stage 3 *******/ - /* H W G CG */ - b->Args({14 * 14, 2, 48}); - - /******** Stage 4 *******/ - /* H W G CG */ - b->Args({ 7 * 7, 2, 96}); -} - -static void ShuffleNetV2x1_0Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 ********/ - /* H W G CG */ - b->Args({28 * 28, 2, 58}); - - /******** Stage 3 ********/ - /* H W G CG */ - b->Args({14 * 14, 2, 116}); - - /******** Stage 4 ********/ - /* H W G CG */ - b->Args({ 7 * 7, 2, 232}); -} - -static void ShuffleNetV2x1_5Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 ********/ - /* H W G CG */ - b->Args({28 * 28, 2, 88}); - - /******** Stage 3 ********/ - /* H W G CG */ - b->Args({14 * 14, 2, 176}); - - /******** Stage 4 ********/ - /* H W G CG */ - b->Args({ 7 * 7, 2, 352}); -} - -static void ShuffleNetV2x2_0Arguments(benchmark::internal::Benchmark* b) -{ - b->ArgNames({"N", "G", "GC"}); - - /******** Stage 2 ********/ - /* H W G CG */ - b->Args({28 * 28, 2, 122}); - - /******** Stage 3 ********/ - /* H W G CG */ - b->Args({14 * 14, 2, 244}); - - /******** Stage 4 ********/ - /* H W G CG */ - b->Args({ 7 * 7, 2, 488}); -} - -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")->Apply(ShuffleNetV1G2Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")->Apply(ShuffleNetV1G3Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")->Apply(ShuffleNetV1G4Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")->Apply(ShuffleNetV1G8Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x05, "ShuffleNet v2 x0.5")->Apply(ShuffleNetV2x0_5Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x10, "ShuffleNet v2 x1.0")->Apply(ShuffleNetV2x1_0Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x15, "ShuffleNet v2 x1.5")->Apply(ShuffleNetV2x1_5Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x20, "ShuffleNet v2 x2.0")->Apply(ShuffleNetV2x2_0Arguments)->UseRealTime(); - -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")->Apply(ShuffleNetV1G2Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")->Apply(ShuffleNetV1G3Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")->Apply(ShuffleNetV1G4Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")->Apply(ShuffleNetV1G8Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v2_x05, "ShuffleNet v2 x0.5")->Apply(ShuffleNetV2x0_5Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v2_x10, "ShuffleNet v2 x1.0")->Apply(ShuffleNetV2x1_0Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v2_x15, "ShuffleNet v2 x1.5")->Apply(ShuffleNetV2x1_5Arguments)->UseRealTime(); -BENCHMARK_CAPTURE(channel_shuffle_x32, shufflenet_v2_x20, "ShuffleNet v2 x2.0")->Apply(ShuffleNetV2x2_0Arguments)->UseRealTime(); - -#ifndef 
XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); -#endif diff --git a/bench/f32-gemm-minmax.cc b/bench/f32-gemm-minmax.cc index 5ee34e2f4106..f10a4f901f5f 100644 --- a/bench/f32-gemm-minmax.cc +++ b/bench/f32-gemm-minmax.cc @@ -1286,6 +1286,306 @@ #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 +#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + static void f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + 
xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/9, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/10, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/11, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + 
xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/9, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/10, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/11, /*nr=*/32, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } 
+ + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast) +#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) static void f32_gemm_minmax_ukernel_1x16__avx512f_broadcast(benchmark::State& state, const char* net) { GEMMBenchmark(state, @@ -2187,6 +2487,116 @@ BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm) + static void f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + 
xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane) + + static void f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + benchmark::utils::CheckNEONFMA); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane) + static void f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64(benchmark::State& state, const char* net) { GEMMBenchmark(state, xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, diff --git a/bench/models/BUILD b/bench/models/BUILD index fd5d9243bf03..0604e9ba0b8a 100644 --- a/bench/models/BUILD +++ b/bench/models/BUILD @@ -32,6 +32,7 @@ xnnpack_cxx_library( xnnpack_benchmark( name = "benchmark", srcs = ["benchmark.cc"], + features = ["-layering_check"], tags = xnnpack_slow_benchmark_tags(), deps = [ ":models", diff --git a/bench/models/benchmark.cc b/bench/models/benchmark.cc index 9568aacfd58a..65e8190027c3 100644 --- a/bench/models/benchmark.cc +++ b/bench/models/benchmark.cc @@ -5,12 +5,10 @@ #include -#include #include #include #include #include -#include #include #include @@ -19,9 +17,11 @@ #include "xnnpack.h" #include "xnnpack/allocator.h" #include "xnnpack/subgraph.h" +#include "src/commandlineflags.h" #include "pthreadpool.h" -int FLAGS_num_threads = 1; +BM_DEFINE_int32(num_threads, 1); +BM_DEFINE_int32(xnn_runtime_flags, 0); struct ModelRuntime { std::unique_ptr model; @@ -85,7 +85,7 @@ struct ModelRuntime { static void 
BenchmarkInvoke(benchmark::State& state, std::function model_factory, - uint32_t flags = 0) { + uint32_t extra_flags = 0) { if (xnn_initialize(nullptr /* allocator */) != xnn_status_success) { state.SkipWithError("failed to initialize XNNPACK"); return; @@ -98,7 +98,7 @@ static void BenchmarkInvoke(benchmark::State& state, } // TODO(dsharlet): We should have benchmarks of these steps too. - if (!model_runtime.CreateRuntime(flags)) { + if (!model_runtime.CreateRuntime(FLAGS_xnn_runtime_flags | extra_flags)) { state.SkipWithError("failed to create runtime"); return; } @@ -188,8 +188,7 @@ static void QD8Attention(benchmark::State& state) { return models::QD8Attention(state.range(0), state.range(1), state.range(2), state.range(3), state.range(4), weights); - }, - 0); + }); } static void QS8MobileNetV2(benchmark::State& state) { @@ -236,22 +235,14 @@ BENCHMARK(QD8Attention) BENCHMARK(QS8MobileNetV2)->Unit(benchmark::kMicrosecond)->UseRealTime(); -int main(int argc, char** argv) { - ::benchmark::Initialize(&argc, argv); - for (int i = 1; i < argc;) { - if (strncmp(argv[i], "--num_threads=", 14) == 0) { - FLAGS_num_threads = atoi(argv[i] + 14); - if (FLAGS_num_threads <= 0) { - std::cerr << "Invalid --num_threads: " << FLAGS_num_threads << "\n"; - return 1; - } - std::copy(argv + i + 1, argv + argc, argv + i); - argc -= 1; - } else { - ++i; - } - } - if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; - ::benchmark::RunSpecifiedBenchmarks(); +#ifdef BENCHMARK_ARGS_BOTTLENECK +// We are provided with a main that will call this function +extern "C" { +int BenchmarkArgBottleneck(int& argc, char**& argv) { + return ProcessArgs(argc, argv); +} } +#else +BENCHMARK_MAIN(); +#endif diff --git a/bench/qd8-f32-qc8w-gemm.cc b/bench/qd8-f32-qc8w-gemm.cc index 906d70576115..f6864d05cfee 100644 --- a/bench/qd8-f32-qc8w-gemm.cc +++ b/bench/qd8-f32-qc8w-gemm.cc @@ -724,6 +724,97 @@ #endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane) +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY static void qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55(benchmark::State& state, const char* net) { GEMMBenchmark(state, @@ -1348,6 +1439,306 @@ #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni(benchmark::State& 
state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/5, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/6, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/9, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/10, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni) + + static void 
qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/11, /*nr=*/32, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni) + + 
static void qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/11, /*nr=*/16, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/2, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/3, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/4, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + 
BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni) + + static void qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w, + /*mr=*/5, /*nr=*/64, /*kr=*/4, /*sr=*/1, + benchmark::utils::CheckAVX512VNNI); + } + + BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni) +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_AVX512VNNI && (XNN_ARCH_X86 || XNN_ARCH_X86_64) static void qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni(benchmark::State& state, const char* net) { GEMMBenchmark(state, diff --git a/bench/qp8-f32-qc8w-gemm.cc b/bench/qp8-f32-qc8w-gemm.cc new file mode 100644 index 000000000000..1d970a7635d9 --- /dev/null +++ b/bench/qp8-f32-qc8w-gemm.cc @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! +// Specification: test/qp8-f32-qc8w-gemm-minmax.yaml +// Generator: tools/generate-gemm-test.py + +#include <benchmark/benchmark.h> +#include "gemm-benchmark.h" +#include "utils.h" +#include "xnnpack/common.h" +#include "xnnpack/gemm.h" +#include "xnnpack/isa-checks.h" +#include "xnnpack/microfnptr.h" +#include "xnnpack/microparams-init.h" +#include "xnnpack/pack.h" +#include "xnnpack/packw.h" + + +#if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + static void qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases, + /*mr=*/16, /*nr=*/4, /*kr=*/8, /*sr=*/1, + /*mr_packed=*/4, + benchmark::utils::CheckNEONI8MM); + } + + BENCHMARK_GEMM(qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4) + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + static void qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases, + /*mr=*/1, /*nr=*/4, /*kr=*/4, /*sr=*/1, + /*mr_packed=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot) + + static void qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases, + /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, + /*mr_packed=*/1, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot) + + static void qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4, +
xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases, + /*mr=*/16, /*nr=*/4, /*kr=*/4, /*sr=*/1, + /*mr_packed=*/4, + benchmark::utils::CheckNEONDOT); + } + + BENCHMARK_GEMM(qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4) + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + + +#ifndef XNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/bench/vunary.cc b/bench/vunary.cc index 5cafe1ee7778..ba679c305862 100644 --- a/bench/vunary.cc +++ b/bench/vunary.cc @@ -294,7 +294,6 @@ void vlrelu(benchmark::State& state, uint64_t arch_flags, #include "qs8-vcvt/qs8-vcvt.h" #include "qu8-f32-vcvt/qu8-f32-vcvt.h" #include "qu8-vcvt/qu8-vcvt.h" -#include "s32-f32-vcvt/s32-f32-vcvt.h" #undef XNN_CVT_UKERNEL_WITH_PARAMS #ifndef XNNPACK_BENCHMARK_NO_MAIN diff --git a/build_config/BUILD.bazel b/build_config/BUILD.bazel index 6352a43acdd4..c4052f8fa9c6 100644 --- a/build_config/BUILD.bazel +++ b/build_config/BUILD.bazel @@ -315,6 +315,21 @@ selects.config_setting_group( ], ) +selects.config_setting_group( + name = "x86_64", + match_any = [ + ":android_x86_64", + ":ios_x86_64", + ":linux_k8", + ":macos_x86_64", + ":macos_x86_64_legacy", + ":tvos_x86_64", + ":watchos_x86_64", + ":windows_x86_64_clang", + ":windows_x86_64_mingw", + ], +) + selects.config_setting_group( name = "riscv", match_any = [":linux_riscv64"], diff --git a/build_defs.bzl b/build_defs.bzl index d0e6b7bb614c..87eb54bbaac6 100644 --- a/build_defs.bzl +++ b/build_defs.bzl @@ -253,7 +253,7 @@ def xnnpack_cxx_library(name, copts = xnnpack_std_cxxopts(), gcc_copts = [], msv **kwargs ) -def xnnpack_unit_test(name, srcs, copts = [], mingw_copts = [], msys_copts = [], deps = [], tags = [], linkopts = [], defines = [], automatic = True, timeout = "short", shard_count = 1, **kwargs): +def xnnpack_unit_test(name, srcs, copts = [], mingw_copts = [], msys_copts = [], deps = [], tags = [], linkopts = [], defines = [], timeout = "short", shard_count = 1, **kwargs): """Unit test binary based on Google Test. Args: @@ -270,81 +270,45 @@ def xnnpack_unit_test(name, srcs, copts = [], mingw_copts = [], msys_copts = [], linkopts: The list of linking options defines: List of predefines macros to be added to the compile line. tags: List of arbitrary text tags. - automatic: Whether to create the test or testable binary. timeout: How long the test is expected to run before returning. shard_count: Specifies the number of parallel shards to use to run the test. **kwargs: Other arguments to pass to the cc_test rule. 
""" - if automatic: - native.cc_test( - name = name, - srcs = srcs, - copts = xnnpack_std_cxxopts() + [ - "-Iinclude", - "-Isrc", - ] + select({ - "//build_config:windows_x86_64_mingw": mingw_copts, - "//build_config:windows_x86_64_msys": msys_copts, - "//conditions:default": [], - }) + select({ - "//build_config:windows_x86_64_clang": ["/clang:-Wno-unused-function"], - "//build_config:windows_x86_64_mingw": ["-Wno-unused-function"], - "//build_config:windows_x86_64_msys": ["-Wno-unused-function"], - "//build_config:windows_x86_64": [], - "//conditions:default": ["-Wno-unused-function"], - }) + copts, - linkopts = select({ - "//build_config:emscripten": xnnpack_emscripten_test_linkopts(), - "//conditions:default": [], - }) + linkopts, - linkstatic = True, - defines = defines, - deps = [ - "@com_google_googletest//:gtest_main", - ] + deps + select({ - "//build_config:emscripten": xnnpack_emscripten_deps(), - "//conditions:default": [], - }), - tags = tags, - timeout = timeout, - shard_count = shard_count, - **kwargs, - ) - else: - native.cc_binary( - name = name, - srcs = srcs, - copts = xnnpack_std_cxxopts() + [ - "-Iinclude", - "-Isrc", - ] + select({ - "//build_config:windows_x86_64_mingw": mingw_copts, - "//build_config:windows_x86_64_msys": msys_copts, - "//conditions:default": [], - }) + select({ - "//build_config:windows_x86_64_clang": ["/clang:-Wno-unused-function"], - "//build_config:windows_x86_64_mingw": ["-Wno-unused-function"], - "//build_config:windows_x86_64_msys": ["-Wno-unused-function"], - "//build_config:windows_x86_64": [], - "//conditions:default": ["-Wno-unused-function"], - }) + copts, - linkopts = select({ - "//build_config:emscripten": xnnpack_emscripten_test_linkopts(), - "//conditions:default": [], - }), - linkstatic = True, - defines = defines, - deps = [ - "@com_google_googletest//:gtest_main", - ] + deps + select({ - "//build_config:emscripten": xnnpack_emscripten_deps(), - "//conditions:default": [], - }), - testonly = True, - tags = tags, - **kwargs, - ) + native.cc_test( + name = name, + srcs = srcs, + copts = xnnpack_std_cxxopts() + [ + "-Iinclude", + "-Isrc", + ] + select({ + "//build_config:windows_x86_64_mingw": mingw_copts, + "//build_config:windows_x86_64_msys": msys_copts, + "//conditions:default": [], + }) + select({ + "//build_config:windows_x86_64_clang": ["/clang:-Wno-unused-function"], + "//build_config:windows_x86_64_mingw": ["-Wno-unused-function"], + "//build_config:windows_x86_64_msys": ["-Wno-unused-function"], + "//build_config:windows_x86_64": [], + "//conditions:default": ["-Wno-unused-function"], + }) + copts, + linkopts = select({ + "//build_config:emscripten": xnnpack_emscripten_test_linkopts(), + "//conditions:default": [], + }) + linkopts, + linkstatic = True, + defines = defines, + deps = [ + "@com_google_googletest//:gtest_main", + ] + deps + select({ + "//build_config:emscripten": xnnpack_emscripten_deps(), + "//conditions:default": [], + }), + tags = tags, + timeout = timeout, + shard_count = shard_count, + **kwargs, + ) def xnnpack_binary(name, srcs, copts = [], deps = [], linkopts = []): """Minimal binary diff --git a/build_params.bzl b/build_params.bzl index d23edcd4163b..807a9fe0aafa 100644 --- a/build_params.bzl +++ b/build_params.bzl @@ -795,4 +795,21 @@ XNNPACK_PARAMS_FOR_ARCH = { ], extra_deps = [], # Extra deps for hexagon. 
), + "amd64": _create_params( + cond = "//build_config:x86_64", + gcc_x86_copts = [ + "-mf16c", + "-mfma", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", + "-mgfni", + ], + msvc_x86_64_copts = ["/arch:AVX512"], + mingw_copts = ["-fno-asynchronous-unwind-tables"], + msys_copts = ["-fno-asynchronous-unwind-tables"], + ), } diff --git a/build_srcs.bzl b/build_srcs.bzl index bd75d4b55b4d..e4c38f759c94 100644 --- a/build_srcs.bzl +++ b/build_srcs.bzl @@ -13,7 +13,6 @@ OPERATOR_SRCS = [ "src/operators/average-pooling-nhwc.c", "src/operators/batch-matrix-multiply-nc.c", "src/operators/binary-elementwise-nd.c", - "src/operators/channel-shuffle-nc.c", "src/operators/constant-pad-nd.c", "src/operators/convolution-nchw.c", "src/operators/convolution-nhwc.c", @@ -108,7 +107,6 @@ XNNPACK_SRCS = [ "src/configs/x8-lut-config.c", "src/configs/xx-fill-config.c", "src/configs/xx-pad-config.c", - "src/configs/zip-config.c", ] LOGGING_SRCS = [ diff --git a/cmake/DownloadCpuinfo.cmake b/cmake/DownloadCpuinfo.cmake index 4dfff8f6f9e5..7bc767e4e2bf 100644 --- a/cmake/DownloadCpuinfo.cmake +++ b/cmake/DownloadCpuinfo.cmake @@ -15,6 +15,7 @@ IF(POLICY CMP0135) CMAKE_POLICY(SET CMP0135 NEW) ENDIF() +# LINT.IfChange INCLUDE(ExternalProject) ExternalProject_Add(cpuinfo URL https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip @@ -27,6 +28,4 @@ ExternalProject_Add(cpuinfo INSTALL_COMMAND "" TEST_COMMAND "" ) - - - +# LINT.ThenChange(../MODULE.bazel:cpuinfo) diff --git a/cmake/DownloadFXdiv.cmake b/cmake/DownloadFXdiv.cmake index e3abe405b2be..ba97ffe54bfa 100644 --- a/cmake/DownloadFXdiv.cmake +++ b/cmake/DownloadFXdiv.cmake @@ -15,6 +15,7 @@ IF(POLICY CMP0135) CMAKE_POLICY(SET CMP0135 NEW) ENDIF() +# LINT.IfChange INCLUDE(ExternalProject) ExternalProject_Add(fxdiv URL https://github.com/Maratyszcza/FXdiv/archive/b408327ac2a15ec3e43352421954f5b1967701d1.zip @@ -26,3 +27,4 @@ ExternalProject_Add(fxdiv INSTALL_COMMAND "" TEST_COMMAND "" ) +# LINT.ThenChange(../MODULE.bazel:FXdiv) diff --git a/cmake/DownloadGoogleBenchmark.cmake b/cmake/DownloadGoogleBenchmark.cmake index ba9594ecd125..ff70a3cac0cb 100644 --- a/cmake/DownloadGoogleBenchmark.cmake +++ b/cmake/DownloadGoogleBenchmark.cmake @@ -15,6 +15,7 @@ IF(POLICY CMP0135) CMAKE_POLICY(SET CMP0135 NEW) ENDIF() +# LINT.IfChange INCLUDE(ExternalProject) ExternalProject_Add(googlebenchmark URL https://github.com/google/benchmark/archive/d2a8a4ee41b923876c034afb939c4fc03598e622.zip @@ -26,3 +27,4 @@ ExternalProject_Add(googlebenchmark INSTALL_COMMAND "" TEST_COMMAND "" ) +# LINT.ThenChange(../MODULE.bazel:benchmark) diff --git a/cmake/DownloadGoogleTest.cmake b/cmake/DownloadGoogleTest.cmake index f3d133a190c6..542650ed29a1 100644 --- a/cmake/DownloadGoogleTest.cmake +++ b/cmake/DownloadGoogleTest.cmake @@ -18,8 +18,8 @@ ENDIF() # LINT.IfChange INCLUDE(ExternalProject) ExternalProject_Add(googletest - URL https://github.com/google/googletest/archive/d144031940543e15423a25ae5a8a74141044862f.zip - URL_HASH SHA256=648b9430fca63acc68c59ee98f624dcbcd9c24ea6b278c306ab6b7f49f62034a + URL https://github.com/google/googletest/archive/35d0c365609296fa4730d62057c487e3cfa030ff.zip + URL_HASH SHA256=307ccaebc77e0acd19d1d09fe856278a66d1936269a999d40accdb46ec3ab6a4 SOURCE_DIR "${CMAKE_BINARY_DIR}/googletest-source" BINARY_DIR "${CMAKE_BINARY_DIR}/googletest" CONFIGURE_COMMAND "" @@ -27,4 +27,4 @@ ExternalProject_Add(googletest INSTALL_COMMAND "" TEST_COMMAND "" ) -# LINT.ThenChange(../WORKSPACE.bazel) +# 
LINT.ThenChange(../MODULE.bazel:googletest) diff --git a/cmake/DownloadKleidiAI.cmake b/cmake/DownloadKleidiAI.cmake index d5fe1e8d8969..fbdd47373390 100644 --- a/cmake/DownloadKleidiAI.cmake +++ b/cmake/DownloadKleidiAI.cmake @@ -15,10 +15,11 @@ IF(POLICY CMP0135) CMAKE_POLICY(SET CMP0135 NEW) ENDIF() +# LINT.IfChange INCLUDE(ExternalProject) ExternalProject_Add(kleidiai - URL https://gitlab.arm.com/kleidi/kleidiai/-/archive/40a926833857fb64786e02f97703e42b1537cb57/kleidiai-40a926833857fb64786e02f97703e42b1537cb57.zip - URL_HASH SHA256=ad37707084a6d4ff41be10cbe8540c75bea057ba79d0de6c367c1bfac6ba0852 + URL https://gitlab.arm.com/kleidi/kleidiai/-/archive/7e8c4baf953227fa447a2f345e5d6491a504aa56/kleidiai-7e8c4baf953227fa447a2f345e5d6491a504aa56.zip + URL_HASH SHA256=8ba8cdb9f945941174d34d10eb4ad158ad1cbc1aef259de5ad992b0bbe85861f SOURCE_DIR "${CMAKE_BINARY_DIR}/kleidiai-source" BINARY_DIR "${CMAKE_BINARY_DIR}/kleidiai" CONFIGURE_COMMAND "" @@ -27,3 +28,4 @@ ExternalProject_Add(kleidiai INSTALL_COMMAND "" TEST_COMMAND "" ) +# LINT.ThenChange(../MODULE.bazel:kleidiai) diff --git a/cmake/DownloadPThreadPool.cmake b/cmake/DownloadPThreadPool.cmake index 6cb67dc00489..eda98dbdb1a8 100644 --- a/cmake/DownloadPThreadPool.cmake +++ b/cmake/DownloadPThreadPool.cmake @@ -18,8 +18,8 @@ ENDIF() # LINT.IfChange INCLUDE(ExternalProject) ExternalProject_Add(pthreadpool - URL https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip - URL_HASH SHA256=a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95 + URL https://github.com/google/pthreadpool/archive/4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0.zip + URL_HASH SHA256=6d373fa7e2b899605fc3b6e72171a71bccbaf9d4d596b7f514535c4ffb966b3b SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" BINARY_DIR "${CMAKE_BINARY_DIR}/pthreadpool" CONFIGURE_COMMAND "" @@ -27,4 +27,4 @@ ExternalProject_Add(pthreadpool INSTALL_COMMAND "" TEST_COMMAND "" ) -# LINT.ThenChange(../WORKSPACE.bazel) +# LINT.ThenChange(../MODULE.bazel:pthreadpool) diff --git a/cmake/gen/aarch64_microkernels.cmake b/cmake/gen/aarch64_microkernels.cmake index 5ab1ee69c45b..caa50bb37336 100644 --- a/cmake/gen/aarch64_microkernels.cmake +++ b/cmake/gen/aarch64_microkernels.cmake @@ -123,6 +123,7 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS src/f32-dwconv/f32-dwconv-9p4c-minmax-asm-aarch64-neonfma.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2-prfm.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2.S + src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2-prfm.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4-prfm.S @@ -134,15 +135,24 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128.S src/f32-gemm/gen/f32-gemm-1x12-minmax-asm-aarch64-neonfma-cortex-a53.S + src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld64.S src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld128.S 
src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75.S src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-ld64.S + src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S src/f32-gemm/gen/f32-gemm-4x12-minmax-asm-aarch64-neonfma-cortex-a53.S + src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75.S + src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S + src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-cortex-a75.S src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S @@ -228,6 +238,14 @@ SET(NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld128.S src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S diff --git a/cmake/gen/amd64_microkernels.cmake b/cmake/gen/amd64_microkernels.cmake new file mode 100644 index 000000000000..e11059c491aa --- /dev/null +++ b/cmake/gen/amd64_microkernels.cmake @@ -0,0 +1,70 @@ +# Copyright 2022 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Description: microkernel filename lists for amd64 +# +# Auto-generated file. Do not edit! 
+# Generator: tools/update-microkernels.py + + +SET(PROD_AMD64_ASM_MICROKERNEL_SRCS) + +SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS + src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S + 
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S + src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S) + +SET(ALL_AMD64_ASM_MICROKERNEL_SRCS ${PROD_AMD64_ASM_MICROKERNEL_SRCS} + ${NON_PROD_AMD64_ASM_MICROKERNEL_SRCS}) diff --git a/cmake/gen/avx256vnni_microkernels.cmake b/cmake/gen/avx256vnni_microkernels.cmake index 4dd622cce8b6..7d8ae2aeaf01 100644 --- a/cmake/gen/avx256vnni_microkernels.cmake +++ b/cmake/gen/avx256vnni_microkernels.cmake @@ -18,6 +18,7 @@ SET(PROD_AVX256VNNI_MICROKERNEL_SRCS src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c + src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni-prfm.c src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c) SET(NON_PROD_AVX256VNNI_MICROKERNEL_SRCS @@ -113,6 +114,11 @@ SET(NON_PROD_AVX256VNNI_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c + src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni-prfm.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni-prfm.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni-prfm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x8c8-minmax-fp32-avx256vnni-prfm.c diff --git a/cmake/gen/avx2_microkernels.cmake b/cmake/gen/avx2_microkernels.cmake index 2a42a5ee2e8f..e1fa4553f096 100644 --- a/cmake/gen/avx2_microkernels.cmake +++ b/cmake/gen/avx2_microkernels.cmake @@ -75,7 +75,6 @@ SET(PROD_AVX2_MICROKERNEL_SRCS src/qu8-vcvt/gen/qu8-vcvt-avx2-u32.c src/qu8-vlrelu/gen/qu8-vlrelu-avx2-u32.c src/s8-vclamp/s8-vclamp-avx2-u128.c - src/s32-f32-vcvt/gen/s32-f32-vcvt-avx2.c src/u8-vclamp/u8-vclamp-avx2-u128.c src/x8-lut/gen/x8-lut-avx2-u128.c src/x8-transposec/gen/x8-transposec-32x32-reuse-switch-avx2.c diff --git a/cmake/gen/avx512f_microkernels.cmake b/cmake/gen/avx512f_microkernels.cmake index 9d2e9aead069..5a445b2dde0e 100644 --- a/cmake/gen/avx512f_microkernels.cmake +++ b/cmake/gen/avx512f_microkernels.cmake @@ -64,7 +64,6 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS src/f32-vunary/gen/f32-vabs-avx512f.c src/f32-vunary/gen/f32-vneg-avx512f.c src/f32-vunary/gen/f32-vsqr-avx512f.c - src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u8.c src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c) diff --git a/cmake/gen/avxvnni_microkernels.cmake b/cmake/gen/avxvnni_microkernels.cmake index 2b6d5189046a..fbb7d3451388 100644 --- a/cmake/gen/avxvnni_microkernels.cmake +++ 
b/cmake/gen/avxvnni_microkernels.cmake @@ -132,8 +132,14 @@ SET(NON_PROD_AVXVNNI_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c + src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni-prfm.c + src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni-prfm.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni-prfm.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avxvnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni-prfm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni.c diff --git a/cmake/gen/hvx_microkernels.cmake b/cmake/gen/hvx_microkernels.cmake index 34b5aad873c2..f04f6ee127d8 100644 --- a/cmake/gen/hvx_microkernels.cmake +++ b/cmake/gen/hvx_microkernels.cmake @@ -106,6 +106,7 @@ SET(NON_PROD_HVX_MICROKERNEL_SRCS src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u32.c src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u64.c src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u96.c - src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u128.c) + src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u128.c + src/x32-packw/gen/x32-packw-gio-hvx-u2.c) SET(ALL_HVX_MICROKERNEL_SRCS ${PROD_HVX_MICROKERNEL_SRCS} + ${NON_PROD_HVX_MICROKERNEL_SRCS}) diff --git a/cmake/gen/microkernels.cmake b/cmake/gen/microkernels.cmake index 6c42bc363421..bd8dcb7b7326 100644 --- a/cmake/gen/microkernels.cmake +++ b/cmake/gen/microkernels.cmake @@ -10,6 +10,7 @@ INCLUDE(cmake/gen/aarch32_microkernels.cmake) INCLUDE(cmake/gen/aarch64_microkernels.cmake) +INCLUDE(cmake/gen/amd64_microkernels.cmake) INCLUDE(cmake/gen/armsimd32_microkernels.cmake) INCLUDE(cmake/gen/avx_microkernels.cmake) INCLUDE(cmake/gen/avx2_microkernels.cmake) diff --git a/cmake/gen/neon_microkernels.cmake b/cmake/gen/neon_microkernels.cmake index 3d789e69f57e..6e714381fcb5 100644 --- a/cmake/gen/neon_microkernels.cmake +++ b/cmake/gen/neon_microkernels.cmake @@ -145,17 +145,12 @@ SET(PROD_NEON_MICROKERNEL_SRCS src/s8-ibilinear/gen/s8-ibilinear-neon-c16.c src/s8-maxpool/s8-maxpool-9p8x-minmax-neon-c16.c src/s8-vclamp/s8-vclamp-neon-u64.c - src/s32-f32-vcvt/gen/s32-f32-vcvt-neon.c src/u8-ibilinear/gen/u8-ibilinear-neon-c8.c src/u8-ibilinear/gen/u8-ibilinear-neon-c16.c src/u8-maxpool/u8-maxpool-9p8x-minmax-neon-c16.c src/u8-rmax/u8-rmax-neon-u16.c src/u8-vclamp/u8-vclamp-neon-u64.c src/x8-transposec/gen/x8-transposec-16x16-reuse-dec-zip-neon.c - src/x8-zip/x8-zip-x2-neon.c - src/x8-zip/x8-zip-x3-neon.c - src/x8-zip/x8-zip-x4-neon.c - src/x8-zip/x8-zip-xm-neon.c src/x16-packw/gen/x16-packw-x8-gemm-goi-neon-ld4lane-u8-prfm.c src/x16-packw/gen/x16-packw-x16-gemm-goi-neon-ld4lane-u8-prfm.c src/x16-transposec/gen/x16-transposec-8x8-reuse-dec-zip-neon.c @@ -165,10 +160,6 @@ SET(PROD_NEON_MICROKERNEL_SRCS src/x32-packw/gen/x32-packw-x8s4-gemm-goi-neon-ld4lane-u4-prfm.c src/x32-transposec/gen/x32-transposec-4x4-reuse-dec-zip-neon.c src/x32-unpool/x32-unpool-neon.c - src/x32-zip/x32-zip-x2-neon.c - src/x32-zip/x32-zip-x3-neon.c - src/x32-zip/x32-zip-x4-neon.c - src/x32-zip/x32-zip-xm-neon.c src/x64-transposec/gen/x64-transposec-2x2-multi-dec-zip-neon.c src/x64-transposec/gen/x64-transposec-2x2-reuse-dec-zip-neon.c 
src/xx-fill/xx-fill-neon-u64.c @@ -829,6 +820,7 @@ SET(NON_PROD_NEON_MICROKERNEL_SRCS src/x16-transposec/gen/x16-transposec-8x8-reuse-mov-zip-neon.c src/x16-transposec/gen/x16-transposec-8x8-reuse-multi-zip-neon.c src/x16-transposec/gen/x16-transposec-8x8-reuse-switch-zip-neon.c + src/x32-packw/gen/x32-packw-gio-neon-u2.c src/x32-packw/gen/x32-packw-x2-gemm-goi-neon-ld2lane-u2.c src/x32-packw/gen/x32-packw-x8-gemm-goi-neon-ld4lane-u4.c src/x32-packw/gen/x32-packw-x8-gemm-goi-neon-ld4lane-u8-prfm.c diff --git a/cmake/gen/neondot_aarch64_microkernels.cmake b/cmake/gen/neondot_aarch64_microkernels.cmake index 0d54cdc6bfcc..8fc0ff78ab77 100644 --- a/cmake/gen/neondot_aarch64_microkernels.cmake +++ b/cmake/gen/neondot_aarch64_microkernels.cmake @@ -12,7 +12,10 @@ SET(PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c - src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c) + src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c + src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c4-aarch64-neondot.c + src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c8-aarch64-neondot.c + src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c4-mstep4-aarch64-neondot.c) SET(NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-aarch64-neondot-ld128.c diff --git a/cmake/gen/neoni8mm_microkernels.cmake b/cmake/gen/neoni8mm_microkernels.cmake index 0d0fab7fa07b..46417ff6fc6a 100644 --- a/cmake/gen/neoni8mm_microkernels.cmake +++ b/cmake/gen/neoni8mm_microkernels.cmake @@ -28,6 +28,7 @@ SET(PROD_NEONI8MM_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c8-minmax-neoni8mm.c src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-16x4c16s2-mstep4-neoni8mm.c src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-8x8c16s2-mstep2-neoni8mm.c + src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c8-mstep4-neoni8mm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-neoni8mm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-neoni8mm.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-neoni8mm.c diff --git a/cmake/gen/neonsme2_microkernels.cmake b/cmake/gen/neonsme2_microkernels.cmake index 53d3e965f188..64c683176c09 100644 --- a/cmake/gen/neonsme2_microkernels.cmake +++ b/cmake/gen/neonsme2_microkernels.cmake @@ -10,6 +10,7 @@ SET(PROD_NEONSME2_MICROKERNEL_SRCS + src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c src/x32-pack-lh/x32-packlh-neonsme2.c) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index a07d6791fc05..a394b9f01ed4 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -156,9 +156,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u1.c src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u4.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c - src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c - src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c - src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-lrintf.c @@ -222,7 +219,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/s8-ibilinear/gen/s8-ibilinear-scalar-c1.c 
src/s8-maxpool/s8-maxpool-9p8x-minmax-scalar-c1.c src/s8-vclamp/s8-vclamp-scalar-u4.c - src/s32-f32-vcvt/gen/s32-f32-vcvt-scalar.c src/u8-ibilinear/gen/u8-ibilinear-scalar-c1.c src/u8-lut32norm/u8-lut32norm-scalar.c src/u8-maxpool/u8-maxpool-9p8x-minmax-scalar-c1.c @@ -235,10 +231,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/x8-packw/gen/x8-packw-x16-gemm-goi-scalar-u2.c src/x8-packw/gen/x8-packw-x32-gemm-goi-scalar-u2.c src/x8-transposec/gen/x8-transposec-2x4-scalar-int.c - src/x8-zip/x8-zip-x2-scalar.c - src/x8-zip/x8-zip-x3-scalar.c - src/x8-zip/x8-zip-x4-scalar.c - src/x8-zip/x8-zip-xm-scalar.c src/x16-packw/gen/x16-packw-x64-gemm-goi-scalar-int-u4.c src/x16-transposec/gen/x16-transposec-2x4-scalar-int.c src/x24-transposec/gen/x24-transposec-1x2-scalar.c @@ -246,10 +238,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/x32-packw/gen/x32-packw-x4-gemm-goi-scalar-float-u4.c src/x32-transposec/gen/x32-transposec-2x4-scalar-int.c src/x32-unpool/x32-unpool-scalar.c - src/x32-zip/x32-zip-x2-scalar.c - src/x32-zip/x32-zip-x3-scalar.c - src/x32-zip/x32-zip-x4-scalar.c - src/x32-zip/x32-zip-xm-scalar.c src/x64-transposec/gen/x64-transposec-4x2-scalar-int.c src/xx-copy/xx-copy-scalar-memcpy.c src/xx-fill/xx-fill-scalar-u16.c @@ -621,6 +609,9 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c + src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c + src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x32c8-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c diff --git a/cmake/gen/sse2_microkernels.cmake b/cmake/gen/sse2_microkernels.cmake index a6d675da34ae..0615ceeec0cf 100644 --- a/cmake/gen/sse2_microkernels.cmake +++ b/cmake/gen/sse2_microkernels.cmake @@ -86,18 +86,10 @@ SET(PROD_SSE2_MICROKERNEL_SRCS src/u8-rmax/u8-rmax-sse2-u16.c src/u8-vclamp/u8-vclamp-sse2-u64.c src/x8-transposec/gen/x8-transposec-16x16-reuse-mov-sse2.c - src/x8-zip/x8-zip-x2-sse2.c - src/x8-zip/x8-zip-x3-sse2.c - src/x8-zip/x8-zip-x4-sse2.c - src/x8-zip/x8-zip-xm-sse2.c src/x16-transposec/gen/x16-transposec-8x8-reuse-multi-sse2.c src/x32-packw/gen/x32-packw-x2c4-gemm-goi-sse2-u4.c src/x32-packw/gen/x32-packw-x8-gemm-goi-sse2-u4.c src/x32-unpool/x32-unpool-sse2.c - src/x32-zip/x32-zip-x2-sse2.c - src/x32-zip/x32-zip-x3-sse2.c - src/x32-zip/x32-zip-x4-sse2.c - src/x32-zip/x32-zip-xm-sse2.c src/x64-transposec/gen/x64-transposec-2x2-multi-mov-sse2.c src/xx-fill/xx-fill-sse2-u64.c src/xx-pad/xx-pad-p16-sse2-u16.c) diff --git a/cmake/gen/sse41_microkernels.cmake b/cmake/gen/sse41_microkernels.cmake index 4bc04de16852..722aa1840a6e 100644 --- a/cmake/gen/sse41_microkernels.cmake +++ b/cmake/gen/sse41_microkernels.cmake @@ -351,6 +351,7 @@ SET(NON_PROD_SSE41_MICROKERNEL_SRCS src/qu8-vmul/gen/qu8-vmul-minmax-fp32-sse41-mul16-ld64-u8.c src/qu8-vmulc/gen/qu8-vmulc-minmax-fp32-sse41-mul16-ld64-u8.c src/s8-ibilinear/gen/s8-ibilinear-sse41-c8.c - src/u8-ibilinear/gen/u8-ibilinear-sse41-c8.c) + src/u8-ibilinear/gen/u8-ibilinear-sse41-c8.c + src/x32-packw/gen/x32-packw-gio-sse41-u2.c) SET(ALL_SSE41_MICROKERNEL_SRCS ${PROD_SSE41_MICROKERNEL_SRCS} + ${NON_PROD_SSE41_MICROKERNEL_SRCS}) diff --git a/cmake/gen/wasmsimd_microkernels.cmake b/cmake/gen/wasmsimd_microkernels.cmake index 
9807847260b2..2d93b7d9717d 100644 --- a/cmake/gen/wasmsimd_microkernels.cmake +++ b/cmake/gen/wasmsimd_microkernels.cmake @@ -206,7 +206,6 @@ SET(PROD_WASMSIMD_MICROKERNEL_SRCS src/s8-ibilinear/gen/s8-ibilinear-wasmsimd-dot16x2-c8.c src/s8-maxpool/s8-maxpool-9p8x-minmax-wasmsimd-c16.c src/s8-vclamp/s8-vclamp-wasmsimd-u64.c - src/s32-f32-vcvt/gen/s32-f32-vcvt-wasmsimd.c src/u8-ibilinear/gen/u8-ibilinear-wasmsimd-dot16x2-c8.c src/u8-maxpool/u8-maxpool-9p8x-minmax-wasmsimd-c16.c src/u8-vclamp/u8-vclamp-wasmsimd-u64.c @@ -217,10 +216,6 @@ SET(PROD_WASMSIMD_MICROKERNEL_SRCS src/x32-packw/gen/x32-packw-x8-gemm-goi-wasmsimd-u4.c src/x32-transposec/gen/x32-transposec-4x4-reuse-mov-wasmsimd.c src/x32-unpool/x32-unpool-wasmsimd.c - src/x32-zip/x32-zip-x2-wasmsimd.c - src/x32-zip/x32-zip-x3-wasmsimd.c - src/x32-zip/x32-zip-x4-wasmsimd.c - src/x32-zip/x32-zip-xm-wasmsimd.c src/xx-fill/xx-fill-wasmsimd-u64.c src/xx-pad/xx-pad-p16-wasmsimd-u16.c) @@ -1016,6 +1011,7 @@ SET(NON_PROD_WASMSIMD_MICROKERNEL_SRCS src/x16-transposec/gen/x16-transposec-8x8-multi-switch-wasmsimd.c src/x16-transposec/gen/x16-transposec-8x8-reuse-multi-wasmsimd.c src/x16-transposec/gen/x16-transposec-8x8-reuse-switch-wasmsimd.c + src/x32-packw/gen/x32-packw-gio-wasmsimd-u2.c src/x32-packw/gen/x32-packw-x8s4-gemm-goi-wasmsimd-u4.c src/x32-packx/x32-packx-4x-wasmsimd.c src/x32-transposec/gen/x32-transposec-4x4-multi-mov-wasmsimd.c diff --git a/gemm_compiler/BUILD b/gemm_compiler/BUILD new file mode 100644 index 000000000000..1809c92142ea --- /dev/null +++ b/gemm_compiler/BUILD @@ -0,0 +1,80 @@ +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +load("@rules_python//python:py_binary.bzl", "py_binary") +load("@rules_python//python:py_library.bzl", "py_library") + +py_binary( + name = "generate_gemm_microkernels_main", + srcs = [ + "generate_gemm_microkernels_main.py", + ], + main = "generate_gemm_microkernels_main.py", + tags = [ + "notap", + ], + deps = [ + ":generate_gemm_microkernels", + ], +) + +py_library( + name = "generate_gemm_microkernels", + srcs = [ + "generate.py", + "generate_f32_gemm_microkernels.py", + "generate_qd8_f32_qc8w_gemm_microkernels.py", + ], + deps = [ + ":aarch64_arch_template", + ":aarch64_isa_templates", + ":x64_arch_template", + ":x64_isa_templates", + ], +) + +py_library( + name = "x64_isa_templates", + srcs = [ + "avx512f_template.py", + "avx512vnni_template.py", + "fma3_template.py", + ], +) + +py_library( + name = "x64_arch_template", + srcs = [ + "x64_template.py", + ], + deps = [ + ":base_architecture", + ], +) + +py_library( + name = "aarch64_isa_templates", + srcs = [ + "neondot_template.py", + "neonfma_template.py", + ], +) + +py_library( + name = "aarch64_arch_template", + srcs = [ + "aarch64_template.py", + ], + deps = [ + ":base_architecture", + ], +) + +py_library( + name = "base_architecture", + srcs = [ + "base_architecture.py", + ], +) diff --git a/gemm_compiler/aarch64_template.py b/gemm_compiler/aarch64_template.py new file mode 100644 index 000000000000..2ef7ed28cd45 --- /dev/null +++ b/gemm_compiler/aarch64_template.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
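The new gemm_compiler package introduced above is organized as a small template hierarchy: base_architecture defines the interface, the architecture templates (aarch64_template, x64_template) supply register conventions and prologue/epilogue handling, and the ISA templates (neonfma/neondot, fma3/avx512f/avx512vnni) override only the SIMD-specific pieces; the generate scripts then stitch header, inner loop, stores, and epilogue together for each (M, N) tile. The sketch below uses hypothetical class and method names, not the generator's actual API, to show the shape of that layering.

```python
class BaseArch:
  """Architecture-level interface every ISA template fills in."""

  def header(self, m, n):
    raise NotImplementedError

  def inner_loop(self, m, n):
    raise NotImplementedError


class ToyAarch64(BaseArch):
  # The architecture layer knows the calling convention and function framing.
  def header(self, m, n):
    return f"BEGIN_FUNCTION toy_gemm_{m}x{n}\n"


class ToyNeonFma(ToyAarch64):
  # The ISA layer only emits the SIMD compute step for each output row.
  def inner_loop(self, m, n):
    return "fmla v16.4s, v2.4s, v10.4s\n" * m


def emit(isa, m, n):
  # The generator composes the pieces into one assembly listing.
  return isa.header(m, n) + isa.inner_loop(m, n)


print(emit(ToyNeonFma(), 2, 8))
```

Keeping the register maps and ASM snippets in overridable methods is what lets one generator produce both the aarch64 neonfma/neondot kernels and the amd64 avx512f/avx512vnni kernels listed in the cmake/gen files above.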
+ +from gemm_compiler import base_architecture as base_architecture + +"""All non SIMD features for aarch64.""" + + +class Aarch64(base_architecture.BaseArchitecture): + + def astride_register(self): + return 'x4' + + def kc_register(self): + return 'x2' + + def k_register(self): + return 'x20' + + def cm_stride_register(self): + return 'x7' + + def am_registers(self): + return [self.a_ptr_register()] + ['x9', 'x10', 'x11', 'x12', 'x21', 'x22'] + + def a_ptr_register(self): + return 'x3' + + def c_ptr_register(self): + return 'x6' + + def cm_registers(self): + return [self.c_ptr_register()] + ['x13', 'x14', 'x15', 'x19', 'x23', 'x24'] + + def w_ptr_register(self): + return 'x5' + + def min_register(self): + return 'v0' + + def max_register(self): + return 'v1' + + def nc_register(self): + return 'x1' + + def mr_register(self): + return 'x0' + + def tmp_gp_registers(self): + return ['x22', 'x23'] + + def dequantize(self, M, N, W): + return '' + + def adjust_kc(self): + return '' + + def register_map_byte(self, reg): + map = { + 'x0': 'x0', + 'x1': 'x1', + 'x2': 'x2', + 'x3': 'x3', + 'x4': 'x4', + 'x5': 'x5', + 'x6': 'x6', + 'x7': 'x7', + 'x8': 'x8', + 'x9': 'x9', + 'x10': 'x10', + 'x11': 'x11', + 'x12': 'x12', + 'x13': 'x13', + 'x14': 'x10', + 'x15': 'x15', + } + return map[reg] + + def register_map_dword(self, reg): + map = { + 'x0': 'q0', + 'x1': 'q1', + 'x2': 'q2', + 'x3': 'q3', + 'x4': 'q4', + 'x5': 'q5', + 'x6': 'q6', + 'x7': 'q7', + 'x8': 'q8', + 'x9': 'q9', + 'x10': 'q10', + 'x11': 'q11', + 'x12': 'q12', + 'x13': 'q13', + 'x14': 'q10', + 'x15': 'q15', + } + return map[reg] + + def function_name(self, M, N, isa): + return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_aarch64_{isa}_lane\n' + + def quantization_params(self): + return '' + + def header(self, M, N, prefix, isa): + HEADER = '#include "xnnpack/assembly.h"\n\n' + + HEADER += 'BEGIN_FUNCTION ' + self.function_name(M, N, isa) + HEADER += """ + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. 
+ ld2r {v0.4s, v1.4s}, [x13]\n""" + HEADER += self.quantization_params() + return HEADER + + def jump_to_label(self, label): + return f'b {label}\n' + + def read_a_registers(self, M): + return '' + + def inner_loop(self, M, N): + N_COUNT = N // self.n_step() + asm_string = '\ninner_loop:\n' + if 'before' in self.input_asm(): + asm_string += self.input_asm()['before'] + for mr in range(0, M): + for l in self.input_asm()['loop']: + asm_string += l.format( + AM_ptr=self.am_registers()[mr], + AM=self.a_registers(mr), + a_offset=self.k_register(), + ) + if 'after' in self.input_asm(): + asm_string += self.input_asm()['after'] + + # weights + if 'before' in self.weights_asm(): + asm_string += self.weights_asm()['before'] + for l in self.weights_asm()['loop_2']: + for nr in range(0, N_COUNT, 2): + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + W_1=self.w_registers()[nr + 1], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + for l in self.weights_asm()['loop']: + if N_COUNT % 2 != 0: + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + if 'after' in self.weights_asm(): + asm_string += self.weights_asm()['after'].format( + W=self.w_ptr_register(), w_step=self.register_bytes() * N_COUNT + ) + + for l in self.compute_asm()['loop']: + for nr in range(0, N_COUNT): + for mr in range(0, M): + asm_string += l.format( + W=self.w_registers()[nr], + A=self.a_registers(mr), + ACC=self.acc_registers()[M * nr + mr], + ) + return asm_string + + def outer_loop_prepare(self, M, N): + return '' + + def input_output_register_setup(self, M): + registers = self.am_registers() + a_stride = self.astride_register() + c_stride = self.cm_stride_register() + a_base_ptr = self.a_ptr_register() + c_base_ptr = self.c_ptr_register() + # setup a{0}->a{M-1} registers + if M == 1: + return '' + asm_string = '# Setup and alias a & c pointers.\n' + asm_string += self.input_output_strides( + M=M, registers=self.am_registers(), stride=self.astride_register() + ) + + # setup c{0}->c{M-1} registers + asm_string += self.input_output_strides( + M=M, registers=self.cm_registers(), stride=self.cm_stride_register() + ) + + # Pre outer loop preparation + # asm_string += isa.outer_loop_prepare(M=M, N=N_COUNT, W=w_ptr_reg, accumulators=acc_registers) + + # if mr < MR + clamp_string, outer = self.clamp_inputs_and_outputs( + M, self.labels(), self.am_registers(), self.cm_registers() + ) + asm_string += clamp_string + return asm_string + + def input_output_strides(self, M, registers, stride): + INPUT_OUTPUT_REGISTER_SETUP = """add {aM}, {aM_1}, {STRIDE}\n""" + ret = '' + for mr in range(1, M): + ret += INPUT_OUTPUT_REGISTER_SETUP.format( + M=mr, + M_1=mr - 1, + aM=registers[mr], + aM_1=registers[mr - 1], + STRIDE=stride, + ) + return ret + + def clamp_inputs_and_outputs( + self, M, labels, input_registers, output_registers + ): + clamping = { + 'clamp': """ + cmp {mr_reg}, {M} + csel {AM_1}, {AM_0}, {AM_1}, LO + csel {CM_1}, {CM_0}, {CM_1}, LO + csel {AM_2}, {AM_1}, {AM_2}, LS + csel {CM_2}, {CM_1}, {CM_2}, LS\n""", + } + ret = '' + outer = M + # clamp a & c + end_index = M if (M % 2 == 1) else M - 1 + for mr in range(2, end_index, 2): + ret += clamping['clamp'].format( + mr_reg=self.mr_register(), + AM_0=input_registers[mr - 2], + AM_1=input_registers[mr - 1], + AM_2=input_registers[mr], + CM_0=output_registers[mr - 2], + CM_1=output_registers[mr - 1], + 
CM_2=output_registers[mr], + M=mr, + ) + if end_index != M: + ret += """ + cmp {mr_reg}, {M} + csel {AM_1}, {AM_0}, {AM_1}, LO + csel {CM_1}, {CM_0}, {CM_1}, LO\n""".format( + mr_reg=self.mr_register(), + AM_0=input_registers[end_index - 1], + AM_1=input_registers[end_index], + CM_0=output_registers[end_index - 1], + CM_1=output_registers[end_index], + M=end_index + 1, + ) + + return ret, outer + + def increment_ptr(self, ptr, step): + return f'add {ptr}, {ptr}, {step}\n' + + def zero_gp_register(self, reg): + return f'eor {reg}, {reg}, {reg}\n' + + def cmp_k_and_jump_if_less(self, label): + kc_register = self.kc_register() + k_register = self.k_register() + return """add {k_register}, {k_register}, 4 + cmp {kc_register}, {k_register} + bne {label}\n""".format( + label=label, k_register=k_register, kc_register=kc_register + ) + + def epilogue(self, M, N, isa): + restore_stack = """ +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION {function_name}""".format( + M=M, N=N, function_name=isa.function_name(M, N, isa.isa()) + ) + return restore_stack diff --git a/gemm_compiler/avx512f_template.py b/gemm_compiler/avx512f_template.py new file mode 100644 index 000000000000..f8f941b32ec5 --- /dev/null +++ b/gemm_compiler/avx512f_template.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from gemm_compiler import fma3_template as isa +from gemm_compiler import x64_template as arch + +"""All SIMD features for avx512f.""" + + +class Avx512F(isa.Fma3): + + def __init__(self): + pass # Empty constructor + + def isa(self): + return 'avx512f' + + def register_bytes(self): + return 64 + + def prefix(self): + return 'z' + + def a_registers(self, idx): + registers = ['zmm2', 'zmm3', 'zmm4', 'zmm5', 'zmm6'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['zmm10', 'zmm11', 'zmm12', 'zmm13'] + + def n_step(self): + return 16 + + def input_asm(self): + res = super().input_asm() + res['compute'] = self.compute_asm()['loop'] + return res + + def dequantize(self, M, N, W): + return '' + + def adjust_kc(self): + return '' + + def compute_asm(self): + c_asm = { + 'loop': ['vfmadd231ps z{ACC}, {A}, {W}\n'], + } + return c_asm + + def outer_loop_prepare(self, M, N): + return '' + + def inner_loop_spill_gp(self, M, N): + asm_string = '\ninner_loop:\n' + N_COUNT = N // self.n_step() + # weights + if 'before' in self.weights_asm(): + asm_string += self.weights_asm()['before'] + for l in self.weights_asm()['loop']: + for nr in range(0, N_COUNT): + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + + # input + if 'before' in self.input_asm(): + asm_string += self.input_asm()['before'] + if 'after' in self.input_asm(): + asm_string += self.input_asm()['after'] + if 'after' in self.weights_asm(): + asm_string += self.weights_asm()['after'].format( + W=self.w_ptr_register(), w_step=self.register_bytes() * N_COUNT + ) + + for mr in range(0, M): + for l in self.input_asm()['loop']: + asm_string += l.format( + AM_ptr=self.am_registers()[mr], + AM=self.a_registers(0), + 
a_offset=self.k_register(), + W=self.w_registers()[nr], + A=self.a_registers(0), + ACC=self.acc_registers()[M * nr + mr], + ) + for m in self.input_asm()['compute']: + for nr in range(0, N_COUNT): + asm_string += m.format( + W=self.w_registers()[nr], + A=self.a_registers(0), + ACC=self.acc_registers()[M * nr + mr], + ) + return asm_string + + def inner_loop_small_M_N(self, M, N): + N_COUNT = N // self.n_step() + asm_string = '\ninner_loop:\n' + # input + if 'before' in self.input_asm(): + asm_string += self.input_asm()['before'] + if 'after' in self.input_asm(): + asm_string += self.input_asm()['after'] + + # weights + if 'before' in self.weights_asm(): + asm_string += self.weights_asm()['before'] + for l in self.weights_asm()['loop']: + for nr in range(0, N_COUNT): + asm_string += l.format( + W_ptr=self.w_ptr_register(), + W=self.w_registers()[nr], + offset=self.register_bytes() * nr, + w_step=self.register_bytes() * N_COUNT, + ) + if 'after' in self.weights_asm(): + asm_string += self.weights_asm()['after'].format( + W=self.w_ptr_register(), w_step=self.register_bytes() * N_COUNT + ) + + for mr in range(0, M): + for l in self.input_asm()['loop']: + asm_string += l.format( + AM_ptr=self.am_registers()[mr], + AM=self.a_registers(mr), + a_offset=self.k_register(), + W=self.w_registers()[nr], + A=self.a_registers(mr), + ACC=self.acc_registers()[M * nr + mr], + ) + for m in self.input_asm()['compute']: + for nr in range(0, N_COUNT): + asm_string += m.format( + W=self.w_registers()[nr], + A=self.a_registers(mr), + ACC=self.acc_registers()[M * nr + mr], + ) + return asm_string + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with the biases.\n' + W = self.w_ptr_register() + accumulators = self.acc_registers() + bias = 'vmovaps z{ACC}, [{W} + {offset}]\n' + for nr in range(0, N): + ret += bias.format( + W=W, ACC=accumulators[nr * M], offset=self.register_bytes() * nr + ) + for nr in range(0, N): + for mr in range(1, M): + ret += self.copy_simd_register( + prefix=self.prefix(), + src=accumulators[M * nr], + dst=accumulators[M * nr + mr], + ) + return ret + + def copy_simd_register(self, prefix, src, dst): + return f'vmovaps {prefix}{dst}, {prefix}{src}\n' + + def store( + self, + M, + N, + ): + tmp_gp_regs = self.tmp_gp_registers() + accumulators = self.acc_registers() + cm_registers = self.cm_registers() + nc_reg = self.nc_register() + nc_lo = self.register_map_byte(nc_reg) + pop_c = M > self.max_M_before_spilling() + N_COUNT = N // self.n_step() + asm_string = '' + c_reg_offset = self.max_M_before_spilling() + if pop_c: + asm_string += '\n' + '# Pop output pointers from the stack.\n' + c_reg_offset = 0 + POP_C = 'mov {C_REG}, [rsp - {offset}]\n' + for mr in range(0, M): + sp_offset = 128 + (mr) * 16 + 8 + asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset) + asm_string += """ + # Check whether full or partial store. 
+ cmp {nc}, {n_step} + jl tail\n""".format(n_step=N, N_2=N // 2, nc=nc_reg) + for mr in range(0, M): + asm_string += """ + vmovups [{c_reg}], z{ACC}""".format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + for nr in range(1, N_COUNT): + asm_string += """ + vmovups [{c_reg} + {offset}], z{ACC}""".format( + ACC=accumulators[M * nr + mr], + c_reg=cm_registers[mr + c_reg_offset], + offset=self.register_bytes() * nr, + ) + asm_string += '\n' + for mr in range(0, M): + asm_string += 'add {cm}, {cn_stride}\n'.format( + cn_stride=N_COUNT * 64, cm=cm_registers[mr + c_reg_offset] + ) + if pop_c: + asm_string += '\n' + '# Write output pointers to the stack.\n' + POP_C = 'mov [rsp - {offset}], {C_REG}\n' + for mr in range(0, M): + sp_offset = 128 + (mr) * 16 + 8 + asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset) + CHECK = """ + sub {nc}, {n_step} + jne outer_loop + jmp return\n""".format(n_step=N, nc=nc_reg) + asm_string += CHECK + + asm_string += '\ntail:' + if N == 64: + asm_string += """ + mov {tmp1}, -1 + sal {tmp1}, {nc_lo} + not {tmp1} + kmovw k1, {tmp1_lo} + shr {tmp1}, 16 + kmovw k2, {tmp1_lo} + shr {tmp1}, 16 + kmovw k3, {tmp1_lo} + shr {tmp1}, 16 + kmovw k4, {tmp1_lo}\n + """.format( + nc_reg=nc_reg, + tmp1=tmp_gp_regs[1], + tmp1_lo=self.register_map_dword(tmp_gp_regs[1]), + nc_lo='cl', + ACC=accumulators[0], + c_reg=cm_registers[0], + ) + for mr in range(0, M): + asm_string += 'vmovups ZMMWORD PTR [{c_reg}]{{k1}}, z{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 64]{{k2}}, z{ACC}\n'.format( + ACC=accumulators[mr + M], c_reg=cm_registers[mr + c_reg_offset] + ) + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 128]{{k3}}, z{ACC}\n'.format( + ACC=accumulators[mr + 2 * M], + c_reg=cm_registers[mr + c_reg_offset], + ) + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 192]{{k4}}, z{ACC}\n'.format( + ACC=accumulators[mr + 3 * M], + c_reg=cm_registers[mr + c_reg_offset], + ) + ) + elif N == 32: + asm_string += """ + mov {tmp1_lo}, -1 + sal {tmp1_lo}, {nc_lo} + not {tmp1_lo} + kmovw k1, {tmp1_lo} + shr {tmp1_lo}, 16 + kmovw k2, {tmp1_lo}\n""".format( + nc_reg=nc_reg, + tmp1_lo=self.register_map_dword(tmp_gp_regs[1]), + nc_lo='cl', + ACC=accumulators[0], + c_reg=cm_registers[0], + ) + for mr in range(0, M): + asm_string += 'vmovups ZMMWORD PTR [{c_reg}]{{k1}}, z{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + asm_string += ( + 'vmovups ZMMWORD PTR [{c_reg} + 64]{{k2}}, z{ACC}\n'.format( + ACC=accumulators[mr + M], c_reg=cm_registers[mr + c_reg_offset] + ) + ) + else: + asm_string += """ + mov {tmp1_lo}, -1 + sal {tmp1_lo}, {nc_lo} + not {tmp1_lo} + kmovw k1, {tmp1_lo}\n""".format( + nc_reg=nc_reg, + tmp1_lo=self.register_map_dword(tmp_gp_regs[1]), + nc_lo='cl', + ACC=accumulators[0], + c_reg=cm_registers[0 + c_reg_offset], + ) + for mr in range(0, M): + asm_string += 'vmovups ZMMWORD PTR [{c_reg}]{{k1}}, z{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr + c_reg_offset] + ) + + return asm_string diff --git a/gemm_compiler/avx512vnni_template.py b/gemm_compiler/avx512vnni_template.py new file mode 100644 index 000000000000..a7da45726b4f --- /dev/null +++ b/gemm_compiler/avx512vnni_template.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
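The tail path emitted by the store() method above builds AVX-512 write masks from the remaining column count: it shifts -1 left by nc, inverts the result so only the nc low bits are set, and feeds successive 16-bit chunks into k1..k4 so each masked vmovups touches only the valid lanes. A small sketch of that bit manipulation (plain Python, with the register plumbing omitted) may make the arithmetic easier to follow.

```python
def tail_masks(nc, n_step=16, num_chunks=4):
  """Per-16-lane store masks for the final, partial column block."""
  bits = ~((-1) << nc)           # nc low bits set, matching the sal/not pair
  masks = []
  for _ in range(num_chunks):
    masks.append(bits & 0xFFFF)  # kmovw consumes the low 16 bits
    bits >>= n_step              # shr 16 advances to the next zmm's lanes
  return masks

# e.g. 37 remaining columns of a 64-wide tile: 16 + 16 + 5 + 0 lanes stored.
print([bin(m) for m in tail_masks(37)])
```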
+ +from gemm_compiler import avx512f_template as isa + +"""All SIMD features for avx512vnni.""" + + +class Avx512Vnni(isa.Avx512F): + + def isa(self): + return 'avx512vnni' + + def a_registers(self, idx): + return 'zmm2' + + def scale_registers(self): + return ['zmm10', 'zmm11', 'zmm2', 'zmm3'] + + def w_registers(self): + return ['zmm6', 'zmm7', 'zmm8', 'zmm9'] + + def acc_registers(self): + return [ + 'mm5', + 'mm12', + 'mm14', + 'mm15', + 'mm16', + 'mm17', + 'mm18', + 'mm19', + 'mm20', + 'mm21', + 'mm22', + 'mm23', + 'mm24', + 'mm25', + 'mm26', + 'mm27', + 'mm28', + 'mm29', + 'mm30', + 'mm4', + 'mm8', + 'mm9', + ] + + def function_name(self, M, N, isa): + return f'xnn_qd8_f32_qc8w_gemm_minmax_ukernel_{M}x{N}c4__asm_amd64_{isa}\n' + + def zp_scale(self, pos): + regs = ['10', '11'] + return regs[pos] + + # kc = round_up_po2(kc, channels) + def adjust_kc(self): + channels = 4 + ret = """ + add {kc_reg}, {channels} + and {kc_reg}, {neg_channels}\n""".format( + kc_reg=self.kc_register(), channels=channels - 1, neg_channels=-channels + ) + return ret + + def quantization_params(self): + return """ + mov {quantization_params_reg}, [rsp + 88] + """.format(quantization_params_reg=self.quantization_params_register()) + + def quantization_params_register(self): + return self.k_register() + + def input_asm(self): + in_asm = { + 'loop': [ + 'vpbroadcastd {AM}, [{AM_ptr} + {a_offset}]\n', + ], + 'compute': ['vpdpbusd z{ACC}, {A}, {W}\n'], + } + return in_asm + + def weights_asm(self): + w_asm = { + 'loop': [ + 'vmovaps {W}, [{W_ptr} + {offset}]\n', + ], + 'after': 'add {W}, {w_step}\n', + } + return w_asm + + def compute_asm(self): + c_asm = { + 'loop': ['vpdpbusd z{ACC}, {A}, {W}\n'], + } + return c_asm + + def dequantize(self, M, N, W): + accumulators = self.acc_registers() + ret = '' + ret += '\n# Convert from int32 to float.\n' + for nr in range(0, N * M): + ret += 'vcvtdq2ps z{ACC}, z{ACC}\n'.format(ACC=accumulators[nr]) + ret += '# Load quantization_params pointer from stack\n' + ret += 'mov {quantization_params_reg}, [rsp + {offset}]\n'.format( + quantization_params_reg=self.quantization_params_register(), + offset=self.stack_size(M) + 88, + ) + for nr in range(0, N): + for mr in range(0, M): + ret += ( + 'vmulps z{ACC}, z{ACC}, DWORD PTR [{quantization_params_reg} +' + ' {offset}]{{1to16}}\n'.format( + ACC=accumulators[nr * M + mr], + offset=4 + mr * 8, + quantization_params_reg=self.quantization_params_register(), + ) + ) + output_scale = 'vmovaps {W_SCALE}, [{W} + {offset}]\n' + # output scales + for nr in range(0, N): + ret += output_scale.format( + W=W, + offset=self.register_bytes() * nr, + W_SCALE=self.scale_registers()[nr], + ) + ret += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + # biases + for nr in range(0, N): + ret += output_scale.format( + W=W, offset=self.register_bytes() * nr, W_SCALE=self.w_registers()[nr] + ) + ret += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + # Intel gets points here for its fma instructions which can accumulate into + # any of the registers. For once, Intel has saner instructions than Arm. 
+ for nr in range(0, N): + for mr in range(0, M): + ret += 'vfmadd213ps z{ACC}, {SCALE}, {BIAS}\n'.format( + ACC=accumulators[nr * M + mr], + SCALE=self.scale_registers()[nr], + BIAS=self.w_registers()[nr], + ) + + return ret + + def outer_loop_prepare(self, M, N): + W = self.w_ptr_register() + accumulators = self.acc_registers() + # outside the outer loop + zp_scale_load_push = ( + """mov {tmp_reg}, [{quantization_params_reg} + {zp_offset}] + vpbroadcastd {tmp_s_reg}, {tmp_reg} + vmovups zmmword ptr [rsp + {offset}], {tmp_s_reg}\n""" + ) + ret = '\n# Load quantization params pointer from stack\n' + ret += 'mov {quantization_params_reg}, [rsp + {offset}]\n'.format( + quantization_params_reg=self.quantization_params_register(), + offset=self.stack_size(M) + 88, + ) + for mr in range(0, M, 1): + ret += zp_scale_load_push.format( + tmp_reg=self.register_map_dword(self.tmp_gp_registers()[0]), + quantization_params_reg=self.quantization_params_register(), + tmp_s_reg=self.w_registers()[0], + offset=464 + mr * 64, + zp_offset=mr * 8, + ) + return ret + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with k_sum * input zero point.\n' + accumulators = self.acc_registers() + W = self.w_ptr_register() + + ksum_x16 = 'vmovaps {KSUM}, [{W} + {offset}]\n' + vksum = 'vpmulld z{ACC}, {KSUM}, ZMMWORD PTR [rsp + {offset}]\n' + + for nr in range(0, N): + ret += ksum_x16.format( + W=W, KSUM=self.w_registers()[nr], offset=self.register_bytes() * nr + ) + for nr in range(0, N): + for mr in range(0, M): + ret += vksum.format( + ACC=accumulators[nr * M + mr], + KSUM=self.w_registers()[nr], + pos=int((mr % 2) * 2), + offset=464 + mr * 64, + ) + + return ret + + def stack_size(self, M): + return 464 + M * 64 diff --git a/gemm_compiler/base_architecture.py b/gemm_compiler/base_architecture.py new file mode 100644 index 000000000000..3caccac4c8a8 --- /dev/null +++ b/gemm_compiler/base_architecture.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
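Taken together, init_accumulators and dequantize in the Avx512Vnni template above implement the usual qd8-f32-qc8w recovery: the integer accumulators, already seeded with k_sum times the input zero point, are converted with vcvtdq2ps, scaled by the per-row input scale, then folded with the per-column weight scale and bias via vfmadd213ps. A scalar reference of the per-element math, assuming that seeding has already happened (a sketch of the arithmetic, not the generated code):

def dequantize_reference(acc_int32, input_scale, weight_scale, bias):
  # vcvtdq2ps: int32 accumulator -> float
  acc = float(acc_int32)
  # vmulps with the broadcast per-row scale from quantization_params
  acc *= input_scale
  # vfmadd213ps acc, weight_scale, bias: acc = acc * weight_scale + bias
  return acc * weight_scale + bias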
+ +from abc import abstractmethod + +from gemm_compiler import base_architecture as base_architecture + +"""Base architecture for GEMM microkernel generation""" + + +class BaseArchitecture: + + def __init__(self): + pass # Empty constructor + + def labels(self): + return [ + 'zero', + 'one', + 'two', + 'three', + 'four', + 'five', + 'six', + 'seven', + 'eight', + 'nine', + 'ten', + 'eleven', + ] + + @abstractmethod + def astride_register(self): + """Returns the register containing a_stride.""" + raise NotImplementedError + + @abstractmethod + def kc_register(self): + """Returns the register containing kc, the number of channels (the reduction dimensions).""" + raise NotImplementedError + + @abstractmethod + def k_register(self): + """Returns the register containing k, the current channel being processed.""" + raise NotImplementedError + + @abstractmethod + def cm_stride_register(self): + """Returns the register containing cm_stride.""" + raise NotImplementedError + + @abstractmethod + def am_registers(self): + """Returns the registers containing the pointers to each row of A (LHS).""" + raise NotImplementedError + + @abstractmethod + def a_ptr_register(self): + """Returns the register containing the A pointer.""" + raise NotImplementedError + + @abstractmethod + def c_ptr_register(self): + """Returns the register containing the C pointer.""" + raise NotImplementedError + + @abstractmethod + def cm_registers(self): + """Returns the registers containing the pointers to each row of C (Output).""" + raise NotImplementedError + + @abstractmethod + def acc_registers(self): + """Returns the accumulator registers.""" + raise NotImplementedError + + @abstractmethod + def w_ptr_register(self): + """Returns the register containing the weight's pointer.""" + raise NotImplementedError + + @abstractmethod + def min_register(self): + """Returns the register containing the min value for clamping.""" + raise NotImplementedError + + @abstractmethod + def max_register(self): + """Returns the register containing the max value for clamping.""" + raise NotImplementedError + + @abstractmethod + def nc_register(self): + """Returns the register containing nc, the number of output rows processed per iteration.""" + raise NotImplementedError + + @abstractmethod + def mr_register(self): + """Returns the register containing mr, the number of input rows processed per kernel call.""" + raise NotImplementedError + + @abstractmethod + def tmp_gp_registers(self): + """Returns some general purpose registers which may be used for storing temporary data.""" + raise NotImplementedError + + @abstractmethod + def jump_to_label(self, label): + """Jump to the given label.""" + raise NotImplementedError + + @abstractmethod + def function_name(self, M, N, isa): + """Returns the microkernel name.""" + raise NotImplementedError + + @abstractmethod + def header(self, M, N, prefix, isa): + """Returns the assembly header.""" + raise NotImplementedError + + @abstractmethod + def input_output_register_setup(self, M): + """Setup the input (A) and output (C) registers.""" + raise NotImplementedError + + @abstractmethod + def max_M_before_spilling(self): + """How large can M be before spilling A and C registers to the stack.""" + raise NotImplementedError + + @abstractmethod + def read_a_registers(self, M): + """Read the A registers from the stack.""" + raise NotImplementedError + + @abstractmethod + def increment_ptr(self, ptr, step): + """Increment the given pointer by step bytes.""" + raise NotImplementedError + + @abstractmethod + def 
zero_gp_register(self, reg): + """Zero the given general purpose register.""" + raise NotImplementedError + + @abstractmethod + def cmp_k_and_jump_if_less(self, label): + """If k is less than kc, then do another iteration of the inner loop.""" + raise NotImplementedError + + @abstractmethod + def epilogue(self, M, N, isa): + """Returns the function epilogue.""" + raise NotImplementedError + + @abstractmethod + def inner_loop(self, M, N): + """Returns the assemebly for the microkernel's inner loop.""" + raise NotImplementedError diff --git a/gemm_compiler/fma3_template.py b/gemm_compiler/fma3_template.py new file mode 100644 index 000000000000..2d59b4249c2e --- /dev/null +++ b/gemm_compiler/fma3_template.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from gemm_compiler import x64_template as arch + +"""All SIMD features for fma3.""" + + +class Fma3(arch.X64): + + def __init__(self): + pass # Empty constructor + + def isa(self): + return 'fma3' + + def register_bytes(self): + return 32 + + def prefix(self): + return 'y' + + def a_registers(self, idx): + registers = ['ymm2', 'ymm3', 'ymm4', 'ymm5'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['ymm14', 'ymm15'] + + def input_asm(self): + in_asm = { + 'loop': [ + 'vbroadcastss {AM}, DWORD PTR [{AM_ptr} + {a_offset}]\n', + ] + } + return in_asm + + def weights_asm(self): + w_asm = { + 'loop': [ + 'vmovaps {W}, [{W_ptr} + {offset}]\n', + ], + 'after': 'add {W}, {w_step}\n', + } + return w_asm + + def compute_asm(self): + c_asm = { + 'loop': ['vfmadd231ps y{ACC}, {A}, {W}\n'], + } + return c_asm + + def load_bias(self): + return 'vmovaps y{ACC}, [{W} + {offset}]\n' + + def copy_simd_register(self, prefix, src, dst): + return f'vmovaps {prefix}{dst}, {prefix}{src}\n' + + def clamp_min(self, reg, prefix): + max_reg = self.max_register() + return f'vminps {prefix}{reg}, {prefix}{max_reg}, {prefix}{reg}\n' + + def clamp_max(self, reg, prefix): + min_reg = self.min_register() + return f'vmaxps {prefix}{reg}, {prefix}{min_reg}, {prefix}{reg}\n' + + def store( + self, + M, + N, + ): + accumulators = self.acc_registers() + cm_registers = self.cm_registers() + nc_reg = self.nc_register() + nc_lo = self.register_map_byte(nc_reg) + N_STEP = 8 + N_COUNT = N // N_STEP + asm_string = """ + cmp {nc}, {n_step} + jl tail_{N_2} + """.format(n_step=N, N_2=N // 2, nc=nc_reg) + for mr in range(0, M): + asm_string += 'vmovups [{c_reg}], y{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for nr in range(1, N_COUNT): + asm_string += 'vmovups [{c_reg} + {offset}], y{ACC}\n'.format( + ACC=accumulators[M * nr + mr], + c_reg=cm_registers[mr], + offset=isa.register_bytes() * nr, + ) + for mr in range(0, M): + asm_string += 'add {cm}, {cn_stride}\n'.format( + cn_stride=cn_stride_reg, cm=cm_registers[mr] + ) + CHECK = """ + sub {nc}, {n_step} + jne {OUTER} + jmp return""".format(n_step=N, nc=nc_reg, OUTER=labels[M]) + asm_string += CHECK + N = N // 2 + if N * 2 > N_STEP: + asm_string += """ + tail_{N}: + test {nc_lo}, {N} + jz tail_{N_2}\n""".format(N=N, N_2=N // 2, nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovups [{c_reg}], y{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + N_COUNT = N // N_STEP + for nr in range(1, N_COUNT): + for mr in range(0, M): + asm_string += 'vmovups [{c_reg} + {offset}], y{ACC}\n'.format( + 
ACC=accumulators[M * nr + mr], + c_reg=cm_registers[mr], + offset=isa.register_bytes() * nr, + ) + for mr in range(0, M): + asm_string += 'vmovaps y{ACC0}, y{ACC1}\n'.format( + ACC0=accumulators[mr], ACC1=accumulators[mr + M * nr] + ) + for mr in range(0, M): + asm_string += 'add {cm}, 32\n'.format( + cn_stride=cn_stride_reg, cm=cm_registers[mr] + ) + asm_string += """ +tail_4: + test {nc_lo}, 4 + jz tail_2\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovups [{c_reg}], x{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'add {c_reg}, 16\n'.format(c_reg=cm_registers[mr]) + for mr in range(0, M): + asm_string += 'vextractf128 x{ACC}, y{ACC}, 1\n'.format( + ACC=accumulators[mr] + ) + asm_string += """ +tail_2: + test {nc_lo}, 2 + jz tail_1\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovlps QWORD PTR [{c_reg}], x{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'add {c_reg}, 8\n'.format(c_reg=cm_registers[mr]) + for mr in range(0, M): + asm_string += 'vmovhlps x{ACC}, x{ACC}, x{ACC}\n'.format( + ACC=accumulators[mr] + ) + asm_string += """ +tail_1: + test {nc_lo}, 1 + jz return\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'vmovss DWORD PTR [{c_reg}], x{ACC}\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + + return asm_string diff --git a/gemm_compiler/generate.py b/gemm_compiler/generate.py new file mode 100644 index 000000000000..097121c0034f --- /dev/null +++ b/gemm_compiler/generate.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
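The Fma3 store tail above finishes a row the same way the AArch64 templates later in this patch do: once the widest stores are done, the leftover nc columns are peeled off in power-of-two chunks (8, 4, 2, 1 for an 8-float YMM path), with vextractf128/vmovhlps shifting the surviving lanes down between steps. Roughly, the chunks visited are the set bits of the remainder; a small model of that decomposition (illustrative names, not part of the generator):

def tail_chunks(nc, n_step=8):
  # Each tail_{k} label fires when bit k of nc is set; after storing k
  # columns the upper lanes are moved down for the next, narrower store.
  chunks, width = [], n_step
  while width >= 1:
    if nc & width:
      chunks.append(width)
    width //= 2
  return chunks

assert tail_chunks(7) == [4, 2, 1]   # 7 leftover columns: tail_4, then tail_2, then tail_1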
+
+import os
+import sys
+
+from gemm_compiler import base_architecture
+
+"""Shared logic for assembly gemm microkernel generation."""
+
+
+def generate_gemm_microkernel(
+    M: int, N: int, isa: base_architecture.BaseArchitecture, output_file: str
+):
+  elements_per_register = isa.n_step()
+  num_horizontal_registers = int(N / elements_per_register)
+  asm_string = isa.header(M, N, isa.prefix(), isa.isa())
+
+  k_register = isa.k_register()
+  acc_registers = isa.acc_registers()
+  w_ptr_reg = isa.w_ptr_register()
+
+  # adjust inner loop
+  asm_string += isa.adjust_kc()
+
+  # Set up a{1}->a{M-1} and c{1}->c{M-1} registers.
+  asm_string += isa.input_output_register_setup(
+      M=M,
+  )
+
+  # Pre outer loop preparation
+  asm_string += isa.outer_loop_prepare(M=M, N=num_horizontal_registers)
+
+  # the outer loop label
+  asm_string += '\nouter_loop:\n'
+  asm_string += '# Zero k counter.\n'
+  asm_string += isa.zero_gp_register(k_register)
+
+  # Read a registers from the stack if required
+  asm_string += isa.read_a_registers(M=M)
+
+  # Initialize accumulators
+  asm_string += isa.init_accumulators(
+      M=M,
+      N=num_horizontal_registers,
+  )
+  asm_string += isa.increment_ptr(
+      ptr=w_ptr_reg, step=isa.register_bytes() * num_horizontal_registers
+  )
+
+  # inner loop
+  asm_string += isa.inner_loop(M, N)
+
+  # loop counter
+  asm_string += isa.cmp_k_and_jump_if_less(label='inner_loop')
+
+  asm_string += isa.dequantize(M=M, N=num_horizontal_registers, W=w_ptr_reg)
+
+  # min/max clamping
+  asm_string += '# Min/max clamping.\n'
+  for nr in range(0, num_horizontal_registers):
+    for mr in range(0, M):
+      asm_string += isa.clamp_min(
+          reg=acc_registers[M * nr + mr], prefix=isa.prefix()
+      )
+  for nr in range(0, num_horizontal_registers):
+    for mr in range(0, M):
+      asm_string += isa.clamp_max(
+          reg=acc_registers[M * nr + mr], prefix=isa.prefix()
+      )
+
+  # store
+  asm_string += isa.store(
+      M=M,
+      N=N,
+  )
+
+  asm_string += isa.epilogue(M, N, isa)
+
+  # Correctly indent the generated assembly.
+  lines = asm_string.splitlines()
+  stripped_lines = [line.lstrip() for line in lines]
+  # Indent all lines that are not labels.
+  stripped_lines = [
+      ' ' + line
+      if not (line.endswith(':') or 'FUNCTION' in line or 'include' in line)
+      else line
+      for line in stripped_lines
+  ]
+  # Strip indentation from empty lines.
+  stripped_lines = ['' if line.isspace() else line for line in stripped_lines]
+  asm_string = '\n'.join(stripped_lines)
+
+  with open(output_file, 'w') as f:
+    f.write(asm_string)
diff --git a/gemm_compiler/generate_f32_gemm_microkernels.py b/gemm_compiler/generate_f32_gemm_microkernels.py
new file mode 100644
index 000000000000..d984467ffb1b
--- /dev/null
+++ b/gemm_compiler/generate_f32_gemm_microkernels.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright 2024 Google LLC
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +import os +import sys + +from gemm_compiler import avx512f_template +from gemm_compiler import fma3_template +from gemm_compiler import generate +from gemm_compiler import neonfma_template + +"""Generates f32 assembly gemm microkernels.""" + + +output_base = 'src/f32-gemm/gen/' + + +def generate_f32_gemm_microkernels(): + if '/bazel-out/' in os.getcwd(): + os.chdir(os.environ['BUILD_WORKING_DIRECTORY']) + + for nr in range(16, 33, 16): + for mr in range(1, 12): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=avx512f_template.Avx512F(), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x{nr}-minmax-asm-amd64-avx512f-broadcast.S', + ), + ) + + # not enough SIMD registers to go above 5x64 + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=64, + isa=avx512f_template.Avx512F(), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x64-minmax-asm-amd64-avx512f-broadcast.S', + ), + ) + + for nr in range(8, 17, 8): + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=neonfma_template.NeonFma(), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x{nr}-minmax-asm-aarch64-neonfma-ld32.S', + ), + ) diff --git a/gemm_compiler/generate_gemm_microkernels_main.py b/gemm_compiler/generate_gemm_microkernels_main.py new file mode 100644 index 000000000000..46dd710fcf39 --- /dev/null +++ b/gemm_compiler/generate_gemm_microkernels_main.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys + +from gemm_compiler import generate_f32_gemm_microkernels as f32 +from gemm_compiler import generate_qd8_f32_qc8w_gemm_microkernels as qd8_f32_qc8w + +"""Generates all assembly gemm microkernels.""" + + +def main(args): + + f32.generate_f32_gemm_microkernels() + qd8_f32_qc8w.generate_qd8_f32_qc8w_gemm_microkernels() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/gemm_compiler/generate_qd8_f32_qc8w_gemm_microkernels.py b/gemm_compiler/generate_qd8_f32_qc8w_gemm_microkernels.py new file mode 100644 index 000000000000..a2c5ab13a59f --- /dev/null +++ b/gemm_compiler/generate_qd8_f32_qc8w_gemm_microkernels.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
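The f32 sweep above stops the 64-column variants at five rows ("not enough SIMD registers to go above 5x64"), and the qd8-f32-qc8w generator that follows applies the same cap: every extra row of a 64-wide tile costs four more ZMM accumulators on top of the weight, broadcast, and clamp registers, and the AVX-512 register file only has 32 of them. A quick sanity check of that accumulator arithmetic (the helper name is illustrative):

def accumulator_count(M, N, n_step=16):
  # One 512-bit register holds n_step = 16 floats, so an M x N tile needs
  # M * (N / n_step) accumulators before any scratch registers are counted.
  return M * (N // n_step)

assert accumulator_count(5, 64) == 20   # plus weights/broadcast/clamp still fits in 32
assert accumulator_count(11, 32) == 22  # the largest 32-column tile generated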
+ +import os +import sys + +from gemm_compiler import avx512vnni_template +from gemm_compiler import generate +from gemm_compiler import neondot_template + +"""Generates qd8-f32-qc8w assembly gemm microkernels.""" + + +output_base = 'src/qd8-f32-qc8w-gemm/gen/' + + +def generate_qd8_f32_qc8w_gemm_microkernels(): + if '/bazel-out/' in os.getcwd(): + os.chdir(os.environ['BUILD_WORKING_DIRECTORY']) + + for nr in range(8, 17, 8): + for mr in range(1, 5): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=neondot_template.NeonDot(), + output_file=os.path.join( + output_base, + f'qd8-f32-qc8w-gemm-{mr}x{nr}-minmax-asm-aarch64-neondot-ld32.S', + ), + ) + + for nr in range(16, 33, 16): + for mr in range(1, 12): + generate.generate_gemm_microkernel( + M=mr, + N=nr, + isa=avx512vnni_template.Avx512Vnni(), + output_file=os.path.join( + output_base, + f'qd8-f32-qc8w-gemm-{mr}x{nr}-minmax-asm-amd64-avx512vnni.S', + ), + ) + + # not enough SIMD registers to go above 5x64 + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=64, + isa=avx512vnni_template.Avx512Vnni(), + output_file=os.path.join( + output_base, + f'qd8-f32-qc8w-gemm-{mr}x64-minmax-asm-amd64-avx512vnni.S', + ), + ) diff --git a/gemm_compiler/neondot_template.py b/gemm_compiler/neondot_template.py new file mode 100644 index 000000000000..70400d9740a7 --- /dev/null +++ b/gemm_compiler/neondot_template.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from gemm_compiler import neonfma_template as isa + +"""All SIMD features for Aarch64 neondot.""" + + +class NeonDot(isa.NeonFma): + + def isa(self): + return 'neondot' + + def a_registers(self, idx): + registers = ['2', '3', '4', '5'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['6', '7', '8', '9'] + + def acc_registers(self): + return [ + '12', + '13', + '14', + '15', + '16', + '17', + '18', + '19', + '20', + '21', + '22', + '23', + '24', + '25', + '26', + '27', + '28', + '29', + '30', + ] + + def function_name(self, M, N, isa): + return f'xnn_qd8_f32_qc8w_gemm_minmax_ukernel_{M}x{N}c4__asm_aarch64_{isa}_lane\n' + + def zp_scale(self, pos): + regs = ['10', '11'] + return regs[pos] + + # kc = round_up_po2(kc, channels) + def adjust_kc(self): + channels = 4 + m = pow(2, 64) - channels + not_channels = f'0x{m:016X}' + ret = '# Round kc up to channels.\n' + ret += """add {kc_reg}, {kc_reg}, #{channels} + and {kc_reg}, {kc_reg}, #{not_channels}\n\n""".format( + kc_reg=self.kc_register(), + channels=channels - 1, + not_channels=not_channels, + ) + return ret + + def quantization_params(self): + return """ldr {quantization_params_reg}, [sp, 16]\n""".format( + quantization_params_reg=self.quantization_params_register() + ) + + def quantization_params_register(self): + return 'x24' + + def compute_asm(self): + c_asm = { + 'loop': ['sdot v{ACC}.4s, v{W}.16b, v{A}.4b[0]\n'], + } + return c_asm + + def dequantize(self, M, N, W): + accumulators = self.acc_registers() + ret = '\n# Convert from int32 to float.\n' + for nr in range(0, N * M): + ret += 'scvtf v{ACC}.4s, v{ACC}.4s\n'.format(ACC=accumulators[nr]) + ret += '# Multiply by input scale.\n' + for nr in range(0, N): + for mr in range(0, M): + ret += 'fmul v{ACC}.4s, v{ACC}.4s, v{zp_scale}.s[{pos}]\n'.format( + ACC=accumulators[nr * M + mr], + zp_scale=self.zp_scale(mr // 2), + pos=int((mr % 2) * 2) + 1, + ) + ret += 
'# Load weights scale.\n' + output_scale_pair = 'ldp q{W_SCALE_0}, q{W_SCALE_1}, [{W}, {offset}]\n' + # output scales + for nr in range(0, N, 2): + ret += output_scale_pair.format( + W=W, + offset=self.register_bytes() * nr, + W_SCALE_0=self.a_registers(nr), + W_SCALE_1=self.a_registers(nr + 1), + ) + ret += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + # biases + ret += '# Load biases.\n' + for nr in range(0, N, 2): + ret += output_scale_pair.format( + W=W, + offset=self.register_bytes() * nr, + W_SCALE_0=self.w_registers()[nr], + W_SCALE_1=self.w_registers()[nr + 1], + ) + ret += 'add {W}, {W}, {increment}\n'.format( + W=W, increment=self.register_bytes() * N + ) + # do mul + add here instead of fmla. + # fmla accumulaltes into the additional term, in this case the bias. This + # means that the bias must be copied before the fmla. + # From the Cortex X1 optimization guide, fmov takes 1 cycle with a + # throughput of 4 and fmla takes 4 cycles with a throughput of 4. This means + # 5 cycles for four movs + fmla. fadd takes 2 cycles with a throughput of 4 + # and fmul takes 3 cycles with a throughput of 4, for a total of 5 cycles + # for 4 results. + ret += "# Multiply by weight's scale.\n" + for nr in range(0, N): + for mr in range(0, M): + ret += 'fmul v{ACC}.4s, v{ACC}.4s, v{SCALE}.4s\n'.format( + ACC=accumulators[nr * M + mr], SCALE=self.a_registers(nr) + ) + ret += '# Add bias.\n' + for nr in range(0, N): + for mr in range(0, M): + ret += 'fadd v{ACC}.4s, v{ACC}.4s, v{BIAS}.4s\n'.format( + ACC=accumulators[nr * M + mr], BIAS=self.w_registers()[nr] + ) + + return ret + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with k_sum * input zero point.\n' + accumulators = self.acc_registers() + W = self.w_ptr_register() + zp_scale_x2 = 'ldr q{zp_scale}, [{quantization_params_reg}]\n' + zp_scale_x4 = ( + 'ldp q{zp_scale_0}, q{zp_scale_1}, [{quantization_params_reg}]\n' + ) + ksum_x8 = 'ldp q{KSUM_0}, q{KSUM_1}, [{W}, {offset}]\n' + vksum = 'mul v{ACC}.4s, v{KSUM}.4s, v{zp_scale}.s[{pos}]\n' + + mr = 0 + for mr in range(0, M - 1, 4): + ret += zp_scale_x4.format( + quantization_params_reg=self.quantization_params_register(), + zp_scale_0=self.zp_scale(mr), + zp_scale_1=self.zp_scale(mr + 1), + ) + if M % 2 == 1: + ret += zp_scale_x2.format( + quantization_params_reg=self.quantization_params_register(), + zp_scale=self.zp_scale(mr), + ) + for nr in range(0, N - 1, 2): + ret += ksum_x8.format( + W=W, + KSUM_0=self.a_registers(nr), + KSUM_1=self.a_registers(nr + 1), + offset=self.register_bytes() * nr, + ) + for nr in range(0, N): + for mr in range(0, M): + ret += vksum.format( + ACC=accumulators[nr * M + mr], + KSUM=self.a_registers(nr), + zp_scale=self.zp_scale(mr // 2), + pos=int((mr % 2) * 2), + ) + + return ret diff --git a/gemm_compiler/neonfma_template.py b/gemm_compiler/neonfma_template.py new file mode 100644 index 000000000000..1e88d9320fe1 --- /dev/null +++ b/gemm_compiler/neonfma_template.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
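As in the AVX512VNNI template, NeonDot.init_accumulators above folds the asymmetric input offset into the starting accumulator values: the sdot inner loop only ever sums raw a_q * w products, and the offset term zp * sum(w) depends solely on the packed per-column weight sums, so it can be hoisted out of the loop entirely. A scalar check of that identity (the sign convention for the packed k_sum is the packing routine's business; treat this only as the underlying algebra):

def dot_with_zero_point(a_q, w, zp):
  # sum((a - zp) * w) == sum(a * w) - zp * sum(w)
  direct = sum((a - zp) * x for a, x in zip(a_q, w))
  hoisted = sum(a * x for a, x in zip(a_q, w)) - zp * sum(w)
  assert direct == hoisted
  return hoisted

dot_with_zero_point([3, 130, 7, 255], [-1, 2, 5, -3], zp=128)   # -857 either way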
+ +import math + +from gemm_compiler import aarch64_template as arch + +"""All SIMD features for Aarch64 neondot.""" + + +class NeonFma(arch.Aarch64): + + def n_step(self): + return 4 + + def isa(self): + return 'neonfma' + + def register_bytes(self): + return 16 + + def prefix(self): + return 'v' + + def acc_registers(self): + return [ + '11', + '12', + '13', + '14', + '15', + '16', + '17', + '18', + '19', + '20', + '21', + '22', + '23', + '24', + '25', + '26', + '27', + '28', + '29', + '30', + ] + + def a_registers(self, idx): + registers = ['2', '3', '4', '5', '6'] + assert idx < len(registers) + return registers[idx] + + def w_registers(self): + return ['7', '8', '9', '10'] + + def input_asm(self): + in_asm = { + 'loop': [ + 'ldr d{AM}, [{AM_ptr}, {a_offset}]\n', + ] + } + return in_asm + + def weights_asm(self): + w_asm = { + 'loop': [ + 'ldr q{W}, [{W_ptr}], 16\n', + ], + 'loop_2': [ + 'ldp q{W}, q{W_1}, [{W_ptr}], 32\n', + ], + } + return w_asm + + def compute_asm(self): + c_asm = { + 'loop': ['fmla v{ACC}.4s, v{W}.4s, v{A}.s[0]\n'], + } + return c_asm + + def init_accumulators(self, M, N): + ret = '# Initialize accumulators with the biases.\n' + accumulators = self.acc_registers() + W = self.w_ptr_register() + single_bias = 'ldr q{ACC}, [{W}, {offset}]\n' + pair_bias = 'ldp q{ACC}, q{ACC_1}, [{W}, {offset}]\n' + + for nr in range(0, N - 1, 2): + ret += pair_bias.format( + W=W, + ACC=accumulators[nr * M], + ACC_1=accumulators[nr * M + M], + offset=self.register_bytes() * nr, + ) + if N % 2 != 0: + ret += single_bias.format( + W=W, + ACC=accumulators[(N - 1) * M], + offset=self.register_bytes() * (N - 1), + ) + for nr in range(0, N): + for mr in range(1, M): + ret += self.copy_simd_register( + prefix=self.prefix(), + src=accumulators[M * nr], + dst=accumulators[M * nr + mr], + ) + + return ret + + def copy_simd_register(self, prefix, src, dst): + return f'mov {prefix}{dst}.16b, {prefix}{src}.16b\n' + + def clamp_min(self, reg, prefix): + max_reg = self.max_register() + return f'fmin {prefix}{reg}.4s, {max_reg}.4s, {prefix}{reg}.4s\n' + + def clamp_max(self, reg, prefix): + min_reg = self.min_register() + return f'fmax {prefix}{reg}.4s, {min_reg}.4s, {prefix}{reg}.4s\n' + + def store( + self, + M, + N, + ): + accumulators = self.acc_registers() + cm_registers = self.cm_registers() + nc_reg = self.nc_register() + nc_lo = self.register_map_byte(nc_reg) + N_COUNT = N // self.n_step() + asm_string = """ + # Check whether full or partial store. 
+ cmp {nc}, {n_step} + b.lo tail_{N_2}\n""".format(n_step=N, N_2=N // 2, nc=nc_reg) + for mr in range(0, M): + asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format( + ACC=accumulators[mr], + ACC_1=accumulators[M + mr], + c_reg=cm_registers[mr], + ) + for nr in range(2, N_COUNT, 2): + asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format( + ACC=accumulators[M * 2 + mr], + ACC_1=accumulators[M * 3 + mr], + c_reg=cm_registers[mr], + ) + CHECK = """ + sub {nc}, {nc}, {n_step} + b.ne outer_loop + b return""".format(n_step=N, nc=nc_reg) + asm_string += CHECK + N = N // 2 + if N * 2 > self.n_step(): + if N == 8: + asm_string += """ +\ntail_8: + tbz {nc_lo}, 3, tail_4\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format( + ACC=accumulators[mr], + ACC_1=accumulators[mr + M], + c_reg=cm_registers[mr], + ) + for mr in range(0, M): + asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format( + ACC0=accumulators[mr], ACC1=accumulators[mr + 2 * M] + ) + asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format( + ACC0=accumulators[mr + M], ACC1=accumulators[mr + 3 * M] + ) + asm_string += """ +\ntail_4: + tbz {nc_lo}, 2, tail_2\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'str q{ACC}, [{c_reg}], 16\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format( + ACC0=accumulators[mr], ACC1=accumulators[mr + M] + ) + asm_string += """ +\ntail_2: + tbz {nc_lo}, 1, tail_1\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'str d{ACC}, [{c_reg}], 8\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + for mr in range(0, M): + asm_string += 'dup d{ACC}, v{ACC}.d[1]\n'.format(ACC=accumulators[mr]) + asm_string += """ +\ntail_1: + tbz {nc_lo}, 0, return\n""".format(nc_lo=nc_lo) + for mr in range(0, M): + asm_string += 'str s{ACC}, [{c_reg}]\n'.format( + ACC=accumulators[mr], c_reg=cm_registers[mr] + ) + + return asm_string diff --git a/gemm_compiler/x64_template.py b/gemm_compiler/x64_template.py new file mode 100644 index 000000000000..ac6886eab05f --- /dev/null +++ b/gemm_compiler/x64_template.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# Copyright 2024 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from gemm_compiler import base_architecture as base_architecture + +"""All non SIMD features for x64.""" + + +class X64(base_architecture.BaseArchitecture): + + def astride_register(self): + return 'r8' + + def kc_register(self): + return 'rdx' + + def k_register(self): + return 'r11' + + def cm_stride_register(self): + return 'r11' + + def am_registers(self): + return [self.a_ptr_register()] + [ + 'rax', + 'r15', + 'r14', + 'r12', + 'r10', + 'r13', + 'rbx', + 'rbp', + 'r8', + 'rdi', + ] + + def a_ptr_register(self): + return 'rsi' + + def c_ptr_register(self): + return 'rsi' + + def cm_registers(self): + return [self.c_ptr_register()] + [ + 'rax', + 'r15', + 'r14', + 'r12', + 'r10', + 'r13', + 'rbx', + 'rbp', + 'r8', + 'rdi', + ] + + def acc_registers(self): + return [ + 'mm7', + 'mm8', + 'mm9', + 'mm14', + 'mm15', + 'mm16', + 'mm17', + 'mm18', + 'mm19', + 'mm20', + 'mm21', + 'mm22', + 'mm23', + 'mm24', + 'mm25', + 'mm26', + 'mm27', + 'mm28', + 'mm29', + 'mm30', + 'mm12', + 'mm13', + ] + + def w_ptr_register(self): + return 'r9' + + def min_register(self): + return 'mm0' + + def max_register(self): + return 'mm1' + + def nc_register(self): + return 'rcx' + + def mr_register(self): + return 'rdi' + + def tmp_gp_registers(self): + return ['rdi', 'r11'] + + def register_map_byte(self, reg): + """Maps 64 bit register names to their low 8 bits.""" + map = { + 'rax': 'al', + 'rcx': 'cl', + 'rdx': 'dl', + 'rbx': 'bl', + 'rsi': 'sil', + 'rdi': 'dil', + 'rsp': 'spl', + 'rbp': 'bpl', + 'r8': 'r8b', + 'r9': 'r9b', + 'r10': 'r10b', + 'r11': 'r11b', + 'r12': 'r12b', + 'r13': 'r13b', + 'r14': 'r14b', + 'r15': 'r15b', + } + return map[reg] + + def register_map_dword(self, reg): + """Maps 64 bit register names to their low 32 bits.""" + map = { + 'rax': 'eax', + 'rcx': 'ecx', + 'rdx': 'edx', + 'rbx': 'ebx', + 'rsi': 'esi', + 'rdi': 'edi', + 'rsp': 'esp', + 'rbp': 'ebp', + 'r8': 'r8d', + 'r9': 'r9d', + 'r10': 'r10d', + 'r11': 'r11d', + 'r12': 'r12d', + 'r13': 'r13d', + 'r14': 'r14d', + 'r15': 'r15d', + } + return map[reg] + + def jump_to_label(self, label): + return f'jmp {label}' + + def function_name(self, M, N, isa): + return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}__asm_amd64_{isa}_broadcast\n' + + def header(self, M, N, prefix, isa): + HEADER = '#include "xnnpack/assembly.h"\n\n' + + HEADER += 'BEGIN_FUNCTION ' + self.function_name(M, N, isa) + HEADER += """ + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss {prefix}mm0, DWORD PTR [r13] + vbroadcastss {prefix}mm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64]\n""".format(M=M, N=N, prefix=prefix, isa=isa) + return HEADER + + def input_output_register_setup(self, M): + registers = self.am_registers() + a_stride = self.astride_register() + c_stride = self.cm_stride_register() + INPUT_OUTPUT_REGISTER_SETUP = """ + # Clamp a & c pointers if mr <= {M} + mov {aM}, {aM_1} + add {aM}, {A_STRIDE} + mov {cM}, {cM_1} + add {cM}, {C_STRIDE} + cmp {mr_reg}, {M} + cmovle {aM}, {aM_1} + cmovle {cM}, {cM_1}\n""" + INPUT_OUTPUT_REGISTER_PUSH = """ + mov [rsp - {a_rsp_offset}], {aM} + mov [rsp - {c_rsp_offset}], {cM}\n""" + ret = '' + if self.stack_size(M) != 0: + ret += """sub rsp, {stack_size}\n""".format( + stack_size=self.stack_size(M) + ) + # Write rsi & r10 if required to the stack. + if M > self.max_M_before_spilling(): + ret += ( + '# Write rsi (a pointer) to the stack as we need the register.\n' + ) + ret += 'mov [rsp - 128], rsi\n' + ret += ( + '# Write r10 (c pointer) to the stack as we need the register.\n' + ) + ret += 'mov [rsp - 136], r10\n' + for mr in range(1, M): + # cycle size of 2 if required + if M > self.max_M_before_spilling(): + a_pos = mr % 2 + c_pos = (mr % 2) + self.max_M_before_spilling() + a_pos_1 = (mr + 1) % 2 + c_pos_1 = ((mr + 1) % 2) + self.max_M_before_spilling() + else: + a_pos = mr + c_pos = mr + self.max_M_before_spilling() + a_pos_1 = a_pos - 1 + c_pos_1 = c_pos - 1 + a_rsp_offset = 144 + (mr - 1) * 16 + ret += INPUT_OUTPUT_REGISTER_SETUP.format( + M=mr, + aM=registers[a_pos], + aM_1=registers[a_pos_1], + cM=registers[c_pos], + cM_1=registers[c_pos_1], + A_STRIDE=a_stride, + C_STRIDE=c_stride, + mr_reg=self.mr_register(), + a_rsp_offset=a_rsp_offset, + c_rsp_offset=a_rsp_offset + 8, + ) + if M > self.max_M_before_spilling(): + ret += INPUT_OUTPUT_REGISTER_PUSH.format( + M=mr, + aM=registers[a_pos], + aM_1=registers[a_pos_1], + cM=registers[c_pos], + cM_1=registers[c_pos_1], + A_STRIDE=a_stride, + C_STRIDE=c_stride, + mr_reg=self.mr_register(), + a_rsp_offset=a_rsp_offset, + c_rsp_offset=a_rsp_offset + 8, + ) + + return ret + + def max_M_before_spilling(self): + return 5 + + def read_a_registers(self, M): + registers = self.am_registers() + if M <= self.max_M_before_spilling(): + return '' + ret = '# Read a pointers from stack into GP registers.\n' + POP_A = 'mov {aM}, [rsp - {a_rsp_offset}]\n' + for mr in range(0, M): + a_rsp_offset = 128 + mr * 16 + ret += POP_A.format(aM=registers[mr], a_rsp_offset=a_rsp_offset) + ret += '\n' + return ret + + def increment_ptr(self, ptr, step): + return f'add {ptr}, {step}\n' + + def zero_gp_register(self, reg): + return f'mov {reg}, 0\n' + + def cmp_k_and_jump_if_less(self, label): + kc_register = self.kc_register() + k_register = self.k_register() + return """ + add {k_register}, 4 + cmp {kc_register}, {k_register} + jne {label}\n""".format( + label=label, k_register=k_register, kc_register=kc_register + ) + + def load_from_stack(self, reg, offset): + """Load 8 bytes from the given offset from the stack pointer to reg.""" + return f'mov {reg}, [rsp - {offset}]\n' + + def epilogue(self, M, N, isa): + restore_stack = '\nreturn:\n' + if isa.stack_size(M) != 0: + restore_stack += 'add rsp, {stack_ptr_sub}\n'.format( + stack_ptr_sub=isa.stack_size(M) + ) + restore_stack += """ + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION {function_name}""".format( + M=M, N=N, function_name=isa.function_name(M, N, isa.isa()) + ) + return restore_stack + + def stack_size(self, M): + """Returns the required stack storage space.""" + return 0 + + def inner_loop(self, M, N): + if M > self.max_M_before_spilling(): + return self.inner_loop_spill_gp(M, N) + else: + return self.inner_loop_small_M_N(M, N) diff --git a/gen/aarch64_microkernels.bzl b/gen/aarch64_microkernels.bzl index 52db94584237..d03ba1a5d09e 100644 --- a/gen/aarch64_microkernels.bzl +++ b/gen/aarch64_microkernels.bzl @@ -120,6 +120,7 @@ NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = [ "src/f32-dwconv/f32-dwconv-9p4c-minmax-asm-aarch64-neonfma.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2-prfm.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neon-ld128-acc2.S", + "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2-prfm.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc2.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld64-acc4-prfm.S", @@ -131,15 +132,24 @@ NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = [ "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S", "src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld128.S", "src/f32-gemm/gen/f32-gemm-1x12-minmax-asm-aarch64-neonfma-cortex-a53.S", + "src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-gemm/gen/f32-gemm-4x1-minmax-asm-aarch64-neonfma-ld128.S", "src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S", "src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-cortex-a75.S", "src/f32-gemm/gen/f32-gemm-4x2-minmax-asm-aarch64-neonfma-ld64.S", + "src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-gemm/gen/f32-gemm-4x12-minmax-asm-aarch64-neonfma-cortex-a53.S", + "src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75-prfm.S", "src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-cortex-a75.S", + "src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S", + "src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S", "src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-cortex-a75.S", "src/f32-gemm/gen/f32-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-gemm/gen/f32-gemm-goi-1x8-minmax-asm-aarch64-neonfma-ld128-prfm.S", @@ -225,6 +235,14 @@ NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = [ "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld64.S", "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x8-minmax-asm-aarch64-neonfma-ld128.S", "src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8-minmax-asm-aarch64-neonfma-ld64.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S", + 
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S", "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-asm-aarch64-neondot-ld64.S", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-asm-aarch64-neon-mlal-cortex-a53.S", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-asm-aarch64-neondot-ld32.S", diff --git a/gen/amd64_microkernels.bzl b/gen/amd64_microkernels.bzl new file mode 100644 index 000000000000..4d06fff1aa0f --- /dev/null +++ b/gen/amd64_microkernels.bzl @@ -0,0 +1,68 @@ +""" +Microkernel filenames lists for amd64. + +Auto-generated file. Do not edit! + Generator: tools/update-microkernels.py +""" + +PROD_AMD64_ASM_MICROKERNEL_SRCS = [ +] + +NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [ + "src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S", + 
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S", + "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S", +] + +AMD64_ASM_MICROKERNEL_SRCS = PROD_AMD64_ASM_MICROKERNEL_SRCS + NON_PROD_AMD64_ASM_MICROKERNEL_SRCS diff --git a/gen/avx256vnni_microkernels.bzl b/gen/avx256vnni_microkernels.bzl index cda5466c81dd..c89549b8bdde 100644 --- a/gen/avx256vnni_microkernels.bzl +++ b/gen/avx256vnni_microkernels.bzl @@ -14,6 +14,7 @@ PROD_AVX256VNNI_MICROKERNEL_SRCS = [ "src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c", "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c", + "src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni-prfm.c", "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c", ] @@ -110,6 +111,11 @@ NON_PROD_AVX256VNNI_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c", "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c", + "src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni-prfm.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni-prfm.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni-prfm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x8c8-minmax-fp32-avx256vnni-prfm.c", diff --git a/gen/avx2_microkernels.bzl b/gen/avx2_microkernels.bzl index 25a39db7efe7..91117654e68e 100644 --- a/gen/avx2_microkernels.bzl +++ b/gen/avx2_microkernels.bzl @@ -71,7 
+71,6 @@ PROD_AVX2_MICROKERNEL_SRCS = [ "src/qu8-vcvt/gen/qu8-vcvt-avx2-u32.c", "src/qu8-vlrelu/gen/qu8-vlrelu-avx2-u32.c", "src/s8-vclamp/s8-vclamp-avx2-u128.c", - "src/s32-f32-vcvt/gen/s32-f32-vcvt-avx2.c", "src/u8-vclamp/u8-vclamp-avx2-u128.c", "src/x8-lut/gen/x8-lut-avx2-u128.c", "src/x8-transposec/gen/x8-transposec-32x32-reuse-switch-avx2.c", diff --git a/gen/avx512f_microkernels.bzl b/gen/avx512f_microkernels.bzl index 14d3e0916e5a..a9088ce1b8a6 100644 --- a/gen/avx512f_microkernels.bzl +++ b/gen/avx512f_microkernels.bzl @@ -60,7 +60,6 @@ PROD_AVX512F_MICROKERNEL_SRCS = [ "src/f32-vunary/gen/f32-vabs-avx512f.c", "src/f32-vunary/gen/f32-vneg-avx512f.c", "src/f32-vunary/gen/f32-vsqr-avx512f.c", - "src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c", "src/x32-packw/gen/x32-packw-x32-gemm-gio-avx512f-u8.c", "src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c", ] diff --git a/gen/avxvnni_microkernels.bzl b/gen/avxvnni_microkernels.bzl index b0ff2945c40e..d39ab43c555f 100644 --- a/gen/avxvnni_microkernels.bzl +++ b/gen/avxvnni_microkernels.bzl @@ -129,8 +129,14 @@ NON_PROD_AVXVNNI_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c", + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni-prfm.c", + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni-prfm.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni-prfm.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avxvnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni-prfm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni.c", diff --git a/gen/hvx_microkernels.bzl b/gen/hvx_microkernels.bzl index 691a12da673c..2cc7db52dfc3 100644 --- a/gen/hvx_microkernels.bzl +++ b/gen/hvx_microkernels.bzl @@ -104,6 +104,7 @@ NON_PROD_HVX_MICROKERNEL_SRCS = [ "src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u64.c", "src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u96.c", "src/qs8-vadd/gen/qs8-vadd-minmax-hvx-u128.c", + "src/x32-packw/gen/x32-packw-gio-hvx-u2.c", ] ALL_HVX_MICROKERNEL_SRCS = PROD_HVX_MICROKERNEL_SRCS + NON_PROD_HVX_MICROKERNEL_SRCS diff --git a/gen/microkernels.bzl b/gen/microkernels.bzl index 988614c49816..24c0255ad1f1 100644 --- a/gen/microkernels.bzl +++ b/gen/microkernels.bzl @@ -7,6 +7,7 @@ Auto-generated file. Do not edit! 
load("aarch32_microkernels.bzl", _AARCH32_ASM_MICROKERNEL_SRCS = "AARCH32_ASM_MICROKERNEL_SRCS", _NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS = "NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS", _PROD_AARCH32_ASM_MICROKERNEL_SRCS = "PROD_AARCH32_ASM_MICROKERNEL_SRCS") load("aarch64_microkernels.bzl", _AARCH64_ASM_MICROKERNEL_SRCS = "AARCH64_ASM_MICROKERNEL_SRCS", _NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = "NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS", _PROD_AARCH64_ASM_MICROKERNEL_SRCS = "PROD_AARCH64_ASM_MICROKERNEL_SRCS") +load("amd64_microkernels.bzl", _AMD64_ASM_MICROKERNEL_SRCS = "AMD64_ASM_MICROKERNEL_SRCS", _NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = "NON_PROD_AMD64_ASM_MICROKERNEL_SRCS", _PROD_AMD64_ASM_MICROKERNEL_SRCS = "PROD_AMD64_ASM_MICROKERNEL_SRCS") load("armsimd32_microkernels.bzl", _ALL_ARMSIMD32_MICROKERNEL_SRCS = "ALL_ARMSIMD32_MICROKERNEL_SRCS", _NON_PROD_ARMSIMD32_MICROKERNEL_SRCS = "NON_PROD_ARMSIMD32_MICROKERNEL_SRCS", _PROD_ARMSIMD32_MICROKERNEL_SRCS = "PROD_ARMSIMD32_MICROKERNEL_SRCS") load("avx256skx_microkernels.bzl", _ALL_AVX256SKX_MICROKERNEL_SRCS = "ALL_AVX256SKX_MICROKERNEL_SRCS", _NON_PROD_AVX256SKX_MICROKERNEL_SRCS = "NON_PROD_AVX256SKX_MICROKERNEL_SRCS", _PROD_AVX256SKX_MICROKERNEL_SRCS = "PROD_AVX256SKX_MICROKERNEL_SRCS") load("avx256vnni_microkernels.bzl", _ALL_AVX256VNNI_MICROKERNEL_SRCS = "ALL_AVX256VNNI_MICROKERNEL_SRCS", _NON_PROD_AVX256VNNI_MICROKERNEL_SRCS = "NON_PROD_AVX256VNNI_MICROKERNEL_SRCS", _PROD_AVX256VNNI_MICROKERNEL_SRCS = "PROD_AVX256VNNI_MICROKERNEL_SRCS") @@ -103,8 +104,10 @@ ALL_SSSE3_MICROKERNEL_SRCS = _ALL_SSSE3_MICROKERNEL_SRCS ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS = _ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS ALL_WASMSIMD_MICROKERNEL_SRCS = _ALL_WASMSIMD_MICROKERNEL_SRCS ALL_WASM_MICROKERNEL_SRCS = _ALL_WASM_MICROKERNEL_SRCS +AMD64_ASM_MICROKERNEL_SRCS = _AMD64_ASM_MICROKERNEL_SRCS NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS = _NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS = _NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS +NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = _NON_PROD_AMD64_ASM_MICROKERNEL_SRCS NON_PROD_ARMSIMD32_MICROKERNEL_SRCS = _NON_PROD_ARMSIMD32_MICROKERNEL_SRCS NON_PROD_AVX256SKX_MICROKERNEL_SRCS = _NON_PROD_AVX256SKX_MICROKERNEL_SRCS NON_PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS = _NON_PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS @@ -155,6 +158,7 @@ NON_PROD_WASMSIMD_MICROKERNEL_SRCS = _NON_PROD_WASMSIMD_MICROKERNEL_SRCS NON_PROD_WASM_MICROKERNEL_SRCS = _NON_PROD_WASM_MICROKERNEL_SRCS PROD_AARCH32_ASM_MICROKERNEL_SRCS = _PROD_AARCH32_ASM_MICROKERNEL_SRCS PROD_AARCH64_ASM_MICROKERNEL_SRCS = _PROD_AARCH64_ASM_MICROKERNEL_SRCS +PROD_AMD64_ASM_MICROKERNEL_SRCS = _PROD_AMD64_ASM_MICROKERNEL_SRCS PROD_ARMSIMD32_MICROKERNEL_SRCS = _PROD_ARMSIMD32_MICROKERNEL_SRCS PROD_AVX256SKX_MICROKERNEL_SRCS = _PROD_AVX256SKX_MICROKERNEL_SRCS PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS = _PROD_AVX256VNNIGFNI_MICROKERNEL_SRCS @@ -306,6 +310,7 @@ NON_PROD_C_SRCS_FOR_ARCH = { PROD_ASM_SRCS_FOR_ARCH = { "aarch32": PROD_AARCH32_ASM_MICROKERNEL_SRCS, "aarch64": PROD_AARCH64_ASM_MICROKERNEL_SRCS, + "amd64": PROD_AMD64_ASM_MICROKERNEL_SRCS, "wasm32": PROD_WASM32_ASM_MICROKERNEL_SRCS, "wasmrelaxedsimd32": PROD_WASMRELAXEDSIMD32_ASM_MICROKERNEL_SRCS, "wasmsimd32": PROD_WASMSIMD32_ASM_MICROKERNEL_SRCS, @@ -314,6 +319,7 @@ PROD_ASM_SRCS_FOR_ARCH = { NON_PROD_ASM_SRCS_FOR_ARCH = { "aarch32": NON_PROD_AARCH32_ASM_MICROKERNEL_SRCS, "aarch64": NON_PROD_AARCH64_ASM_MICROKERNEL_SRCS, + "amd64": NON_PROD_AMD64_ASM_MICROKERNEL_SRCS, "wasm32": NON_PROD_WASM32_ASM_MICROKERNEL_SRCS, "wasmrelaxedsimd32": 
NON_PROD_WASMRELAXEDSIMD32_ASM_MICROKERNEL_SRCS, "wasmsimd32": NON_PROD_WASMSIMD32_ASM_MICROKERNEL_SRCS, diff --git a/gen/neon_microkernels.bzl b/gen/neon_microkernels.bzl index 8e9366ca3140..9255594ed8e3 100644 --- a/gen/neon_microkernels.bzl +++ b/gen/neon_microkernels.bzl @@ -141,17 +141,12 @@ PROD_NEON_MICROKERNEL_SRCS = [ "src/s8-ibilinear/gen/s8-ibilinear-neon-c16.c", "src/s8-maxpool/s8-maxpool-9p8x-minmax-neon-c16.c", "src/s8-vclamp/s8-vclamp-neon-u64.c", - "src/s32-f32-vcvt/gen/s32-f32-vcvt-neon.c", "src/u8-ibilinear/gen/u8-ibilinear-neon-c8.c", "src/u8-ibilinear/gen/u8-ibilinear-neon-c16.c", "src/u8-maxpool/u8-maxpool-9p8x-minmax-neon-c16.c", "src/u8-rmax/u8-rmax-neon-u16.c", "src/u8-vclamp/u8-vclamp-neon-u64.c", "src/x8-transposec/gen/x8-transposec-16x16-reuse-dec-zip-neon.c", - "src/x8-zip/x8-zip-x2-neon.c", - "src/x8-zip/x8-zip-x3-neon.c", - "src/x8-zip/x8-zip-x4-neon.c", - "src/x8-zip/x8-zip-xm-neon.c", "src/x16-packw/gen/x16-packw-x8-gemm-goi-neon-ld4lane-u8-prfm.c", "src/x16-packw/gen/x16-packw-x16-gemm-goi-neon-ld4lane-u8-prfm.c", "src/x16-transposec/gen/x16-transposec-8x8-reuse-dec-zip-neon.c", @@ -161,10 +156,6 @@ PROD_NEON_MICROKERNEL_SRCS = [ "src/x32-packw/gen/x32-packw-x8s4-gemm-goi-neon-ld4lane-u4-prfm.c", "src/x32-transposec/gen/x32-transposec-4x4-reuse-dec-zip-neon.c", "src/x32-unpool/x32-unpool-neon.c", - "src/x32-zip/x32-zip-x2-neon.c", - "src/x32-zip/x32-zip-x3-neon.c", - "src/x32-zip/x32-zip-x4-neon.c", - "src/x32-zip/x32-zip-xm-neon.c", "src/x64-transposec/gen/x64-transposec-2x2-multi-dec-zip-neon.c", "src/x64-transposec/gen/x64-transposec-2x2-reuse-dec-zip-neon.c", "src/xx-fill/xx-fill-neon-u64.c", @@ -826,6 +817,7 @@ NON_PROD_NEON_MICROKERNEL_SRCS = [ "src/x16-transposec/gen/x16-transposec-8x8-reuse-mov-zip-neon.c", "src/x16-transposec/gen/x16-transposec-8x8-reuse-multi-zip-neon.c", "src/x16-transposec/gen/x16-transposec-8x8-reuse-switch-zip-neon.c", + "src/x32-packw/gen/x32-packw-gio-neon-u2.c", "src/x32-packw/gen/x32-packw-x2-gemm-goi-neon-ld2lane-u2.c", "src/x32-packw/gen/x32-packw-x8-gemm-goi-neon-ld4lane-u4.c", "src/x32-packw/gen/x32-packw-x8-gemm-goi-neon-ld4lane-u8-prfm.c", diff --git a/gen/neondot_aarch64_microkernels.bzl b/gen/neondot_aarch64_microkernels.bzl index ec73493e4d90..1abdb2afe929 100644 --- a/gen/neondot_aarch64_microkernels.bzl +++ b/gen/neondot_aarch64_microkernels.bzl @@ -9,6 +9,9 @@ PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x4c16s2-aarch64-neondot.c", "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-1x8c16s2-aarch64-neondot.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-1x8c16s2-aarch64-neondot.c", + "src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c4-aarch64-neondot.c", + "src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c8-aarch64-neondot.c", + "src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c4-mstep4-aarch64-neondot.c", ] NON_PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ diff --git a/gen/neoni8mm_microkernels.bzl b/gen/neoni8mm_microkernels.bzl index 35c03f7f01db..8a40b2f91611 100644 --- a/gen/neoni8mm_microkernels.bzl +++ b/gen/neoni8mm_microkernels.bzl @@ -24,6 +24,7 @@ PROD_NEONI8MM_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c8-minmax-neoni8mm.c", "src/qp8-f32-qb4w-gemm/qp8-f32-qb4w-gemm-minmax-16x4c16s2-mstep4-neoni8mm.c", "src/qp8-f32-qc4w-gemm/qp8-f32-qc4w-gemm-minmax-8x8c16s2-mstep2-neoni8mm.c", + "src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c8-mstep4-neoni8mm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-neoni8mm.c", 
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x16c8-minmax-fp32-neoni8mm.c", "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c8-minmax-fp32-neoni8mm.c", diff --git a/gen/neonsme2_microkernels.bzl b/gen/neonsme2_microkernels.bzl index 2bb071f179db..0964a91321ad 100644 --- a/gen/neonsme2_microkernels.bzl +++ b/gen/neonsme2_microkernels.bzl @@ -6,6 +6,7 @@ Auto-generated file. Do not edit! """ PROD_NEONSME2_MICROKERNEL_SRCS = [ + "src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c", "src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c", "src/x32-pack-lh/x32-packlh-neonsme2.c", ] diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index 2243ba6160b0..2b65594d1c41 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -152,9 +152,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u1.c", "src/qs8-f32-vcvt/gen/qs8-f32-vcvt-scalar-u4.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c", - "src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c", - "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c", - "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-lrintf.c", @@ -218,7 +215,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/s8-ibilinear/gen/s8-ibilinear-scalar-c1.c", "src/s8-maxpool/s8-maxpool-9p8x-minmax-scalar-c1.c", "src/s8-vclamp/s8-vclamp-scalar-u4.c", - "src/s32-f32-vcvt/gen/s32-f32-vcvt-scalar.c", "src/u8-ibilinear/gen/u8-ibilinear-scalar-c1.c", "src/u8-lut32norm/u8-lut32norm-scalar.c", "src/u8-maxpool/u8-maxpool-9p8x-minmax-scalar-c1.c", @@ -231,10 +227,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/x8-packw/gen/x8-packw-x16-gemm-goi-scalar-u2.c", "src/x8-packw/gen/x8-packw-x32-gemm-goi-scalar-u2.c", "src/x8-transposec/gen/x8-transposec-2x4-scalar-int.c", - "src/x8-zip/x8-zip-x2-scalar.c", - "src/x8-zip/x8-zip-x3-scalar.c", - "src/x8-zip/x8-zip-x4-scalar.c", - "src/x8-zip/x8-zip-xm-scalar.c", "src/x16-packw/gen/x16-packw-x64-gemm-goi-scalar-int-u4.c", "src/x16-transposec/gen/x16-transposec-2x4-scalar-int.c", "src/x24-transposec/gen/x24-transposec-1x2-scalar.c", @@ -242,10 +234,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/x32-packw/gen/x32-packw-x4-gemm-goi-scalar-float-u4.c", "src/x32-transposec/gen/x32-transposec-2x4-scalar-int.c", "src/x32-unpool/x32-unpool-scalar.c", - "src/x32-zip/x32-zip-x2-scalar.c", - "src/x32-zip/x32-zip-x3-scalar.c", - "src/x32-zip/x32-zip-x4-scalar.c", - "src/x32-zip/x32-zip-xm-scalar.c", "src/x64-transposec/gen/x64-transposec-4x2-scalar-int.c", "src/xx-copy/xx-copy-scalar-memcpy.c", "src/xx-fill/xx-fill-scalar-u16.c", @@ -618,6 +606,9 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-packw/gen/qs8-packw-x32c4-gemm-gio-scalar.c", "src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x64c4-gemm-gio-scalar.c", + "src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-scalar.c", + "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-scalar.c", "src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x32c8-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-4p2c-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-5f5m5l1c1s1r-minmax-fp32-scalar-fmagic.c", diff --git a/gen/sse2_microkernels.bzl b/gen/sse2_microkernels.bzl index 8375c7af6565..f5d1f36709b6 100644 --- a/gen/sse2_microkernels.bzl +++ 
b/gen/sse2_microkernels.bzl @@ -82,18 +82,10 @@ PROD_SSE2_MICROKERNEL_SRCS = [ "src/u8-rmax/u8-rmax-sse2-u16.c", "src/u8-vclamp/u8-vclamp-sse2-u64.c", "src/x8-transposec/gen/x8-transposec-16x16-reuse-mov-sse2.c", - "src/x8-zip/x8-zip-x2-sse2.c", - "src/x8-zip/x8-zip-x3-sse2.c", - "src/x8-zip/x8-zip-x4-sse2.c", - "src/x8-zip/x8-zip-xm-sse2.c", "src/x16-transposec/gen/x16-transposec-8x8-reuse-multi-sse2.c", "src/x32-packw/gen/x32-packw-x2c4-gemm-goi-sse2-u4.c", "src/x32-packw/gen/x32-packw-x8-gemm-goi-sse2-u4.c", "src/x32-unpool/x32-unpool-sse2.c", - "src/x32-zip/x32-zip-x2-sse2.c", - "src/x32-zip/x32-zip-x3-sse2.c", - "src/x32-zip/x32-zip-x4-sse2.c", - "src/x32-zip/x32-zip-xm-sse2.c", "src/x64-transposec/gen/x64-transposec-2x2-multi-mov-sse2.c", "src/xx-fill/xx-fill-sse2-u64.c", "src/xx-pad/xx-pad-p16-sse2-u16.c", diff --git a/gen/sse41_microkernels.bzl b/gen/sse41_microkernels.bzl index e2e0127307b2..5864721f03a0 100644 --- a/gen/sse41_microkernels.bzl +++ b/gen/sse41_microkernels.bzl @@ -349,6 +349,7 @@ NON_PROD_SSE41_MICROKERNEL_SRCS = [ "src/qu8-vmulc/gen/qu8-vmulc-minmax-fp32-sse41-mul16-ld64-u8.c", "src/s8-ibilinear/gen/s8-ibilinear-sse41-c8.c", "src/u8-ibilinear/gen/u8-ibilinear-sse41-c8.c", + "src/x32-packw/gen/x32-packw-gio-sse41-u2.c", ] ALL_SSE41_MICROKERNEL_SRCS = PROD_SSE41_MICROKERNEL_SRCS + NON_PROD_SSE41_MICROKERNEL_SRCS diff --git a/gen/wasmsimd_microkernels.bzl b/gen/wasmsimd_microkernels.bzl index d364183040ea..2925413d9d6f 100644 --- a/gen/wasmsimd_microkernels.bzl +++ b/gen/wasmsimd_microkernels.bzl @@ -202,7 +202,6 @@ PROD_WASMSIMD_MICROKERNEL_SRCS = [ "src/s8-ibilinear/gen/s8-ibilinear-wasmsimd-dot16x2-c8.c", "src/s8-maxpool/s8-maxpool-9p8x-minmax-wasmsimd-c16.c", "src/s8-vclamp/s8-vclamp-wasmsimd-u64.c", - "src/s32-f32-vcvt/gen/s32-f32-vcvt-wasmsimd.c", "src/u8-ibilinear/gen/u8-ibilinear-wasmsimd-dot16x2-c8.c", "src/u8-maxpool/u8-maxpool-9p8x-minmax-wasmsimd-c16.c", "src/u8-vclamp/u8-vclamp-wasmsimd-u64.c", @@ -213,10 +212,6 @@ PROD_WASMSIMD_MICROKERNEL_SRCS = [ "src/x32-packw/gen/x32-packw-x8-gemm-goi-wasmsimd-u4.c", "src/x32-transposec/gen/x32-transposec-4x4-reuse-mov-wasmsimd.c", "src/x32-unpool/x32-unpool-wasmsimd.c", - "src/x32-zip/x32-zip-x2-wasmsimd.c", - "src/x32-zip/x32-zip-x3-wasmsimd.c", - "src/x32-zip/x32-zip-x4-wasmsimd.c", - "src/x32-zip/x32-zip-xm-wasmsimd.c", "src/xx-fill/xx-fill-wasmsimd-u64.c", "src/xx-pad/xx-pad-p16-wasmsimd-u16.c", ] @@ -1013,6 +1008,7 @@ NON_PROD_WASMSIMD_MICROKERNEL_SRCS = [ "src/x16-transposec/gen/x16-transposec-8x8-multi-switch-wasmsimd.c", "src/x16-transposec/gen/x16-transposec-8x8-reuse-multi-wasmsimd.c", "src/x16-transposec/gen/x16-transposec-8x8-reuse-switch-wasmsimd.c", + "src/x32-packw/gen/x32-packw-gio-wasmsimd-u2.c", "src/x32-packw/gen/x32-packw-x8s4-gemm-goi-wasmsimd-u4.c", "src/x32-packx/x32-packx-4x-wasmsimd.c", "src/x32-transposec/gen/x32-transposec-4x4-multi-mov-wasmsimd.c", diff --git a/include/xnnpack.h b/include/xnnpack.h index c97ffff57345..47ebbfc474e7 100644 --- a/include/xnnpack.h +++ b/include/xnnpack.h @@ -291,6 +291,9 @@ enum xnn_datatype { xnn_datatype_pfp32 = 13, /// BFloat16, i.e. the upper 16 bits of a float32. xnn_datatype_bf16 = 14, + /// Dynamically quantized 8-bit unsigned integer with per-batch quantization + /// parameters. + xnn_datatype_qduint8 = 15, }; /// Define a tensor-type Value and add it to a Subgraph. 
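For context on the new xnn_datatype_qduint8 entry, a minimal sketch of declaring a dynamically quantized subgraph input with it. The call mirrors the existing qdint8 path via xnn_define_dynamically_quantized_tensor_value; the exact signature and whether it already accepts the new enum value are assumptions, not something this diff shows.

    #include <stddef.h>
    #include <stdint.h>
    #include "xnnpack.h"

    // Hypothetical helper: declare a [batch, channels] external input that will be
    // dynamically quantized to unsigned 8-bit with per-batch quantization parameters.
    static enum xnn_status define_qduint8_input(xnn_subgraph_t subgraph,
                                                size_t batch, size_t channels,
                                                uint32_t external_id, uint32_t* id_out) {
      const size_t dims[2] = {batch, channels};
      return xnn_define_dynamically_quantized_tensor_value(
          subgraph, xnn_datatype_qduint8,
          /*num_dims=*/2, /*num_non_batch_dims=*/1, dims,
          external_id, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, id_out);
    }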
@@ -2677,42 +2680,6 @@ enum xnn_status xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w( const struct xnn_quantization_params* quantization_params, float* output); -enum xnn_status xnn_create_channel_shuffle_nc_x8( - size_t groups, - size_t group_channels, - size_t input_stride, - size_t output_stride, - uint32_t flags, - xnn_operator_t* channel_shuffle_op_out); - -enum xnn_status xnn_reshape_channel_shuffle_nc_x8( - xnn_operator_t channel_shuffle_op, - size_t batch_size, - pthreadpool_t threadpool); - -enum xnn_status xnn_setup_channel_shuffle_nc_x8( - xnn_operator_t channel_shuffle_op, - const void* input, - void* output); - -enum xnn_status xnn_create_channel_shuffle_nc_x32( - size_t groups, - size_t group_channels, - size_t input_stride, - size_t output_stride, - uint32_t flags, - xnn_operator_t* channel_shuffle_op_out); - -enum xnn_status xnn_reshape_channel_shuffle_nc_x32( - xnn_operator_t channel_shuffle_op, - size_t batch_size, - pthreadpool_t threadpool); - -enum xnn_status xnn_setup_channel_shuffle_nc_x32( - xnn_operator_t channel_shuffle_op, - const void* input, - void* output); - enum xnn_status xnn_create_constant_pad_nd_x8( const void* padding_value, uint32_t flags, diff --git a/scripts/generate-s32-f32-vcvt.sh b/scripts/generate-s32-f32-vcvt.sh deleted file mode 100755 index d2f5a7fa6c06..000000000000 --- a/scripts/generate-s32-f32-vcvt.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/sh -# Copyright 2024 Google LLC -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -################################## ARM NEON ################################### -tools/xngen src/s32-f32-vcvt/simd.c.in -D BATCH_TILES=4,8,12,16 -D ARCH=neon -o src/s32-f32-vcvt/gen/s32-f32-vcvt-neon.c - -################################# x86 AVX2 ################################# -tools/xngen src/s32-f32-vcvt/simd.c.in -D BATCH_TILES=8,16,24,32 -D ARCH=avx2 -o src/s32-f32-vcvt/gen/s32-f32-vcvt-avx2.c - -################################# x86 AVX512 ################################# -tools/xngen src/s32-f32-vcvt/simd.c.in -D BATCH_TILES=16,32,48,64 -D ARCH=avx512f -o src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c - -################################## WAsm SIMD ################################## -tools/xngen src/s32-f32-vcvt/simd.c.in -D BATCH_TILES=4,8,12,16 -D ARCH=wasmsimd -o src/s32-f32-vcvt/gen/s32-f32-vcvt-wasmsimd.c - -#################################### Scalar ################################### -tools/xngen src/s32-f32-vcvt/simd.c.in -D BATCH_TILES=1,2,3,4 -D ARCH=scalar -o src/s32-f32-vcvt/gen/s32-f32-vcvt-scalar.c - -wait diff --git a/scripts/generate-tests.sh b/scripts/generate-tests.sh index 341f669e8e48..d8b2c145b004 100755 --- a/scripts/generate-tests.sh +++ b/scripts/generate-tests.sh @@ -25,6 +25,7 @@ tools/generate-gemm-test.py --spec test/f32-qc8w-gemm.yaml --output-test tools/generate-gemm-test.py --spec test/f32-qc8w-gemm-relu.yaml --output-test test/f32-qc8w-gemm-relu.cc & tools/generate-gemm-test.py --spec test/f32-qc8w-gemm-minmax.yaml --output-test test/f32-qc8w-gemm-minmax.cc & +tools/generate-gemm-test.py --spec test/qu8-gemm-minmax-rndnu.yaml --output-test test/qu8-gemm-minmax-rndnu16.cc tools/generate-gemm-test.py --spec test/qu8-gemm-minmax-fp32.yaml --output-test test/qu8-gemm-minmax-fp32.cc --output-test test/qu8-gemm-minmax-fp32-2.cc --output-bench bench/qu8-gemm-fp32.cc & tools/generate-gemm-test.py --spec test/qu8-gemm-minmax-rndnu.yaml --output-test test/qu8-gemm-minmax-rndnu.cc --output-test 
test/qu8-gemm-minmax-rndnu-2.cc --output-bench bench/qu8-gemm-rndnu.cc & @@ -36,6 +37,7 @@ tools/generate-gemm-test.py --spec test/qd8-f32-qc4w-gemm-minmax.yaml --output-t tools/generate-gemm-test.py --spec test/qd8-f32-qb4w-gemm-minmax.yaml --output-test test/qd8-f32-qb4w-gemm-minmax.cc --output-bench bench/qd8-f32-qb4w-gemm.cc & tools/generate-gemm-test.py --spec test/qp8-f32-qc4w-gemm-minmax.yaml --output-test test/qp8-f32-qc4w-gemm-minmax.cc --output-bench bench/qp8-f32-qc4w-gemm.cc & +tools/generate-gemm-test.py --spec test/qp8-f32-qc8w-gemm-minmax.yaml --output-test test/qp8-f32-qc8w-gemm-minmax.cc --output-bench bench/qp8-f32-qc8w-gemm.cc & tools/generate-gemm-test.py --spec test/qp8-f32-qb4w-gemm-minmax.yaml --output-test test/qp8-f32-qb4w-gemm-minmax.cc --output-bench bench/qp8-f32-qb4w-gemm.cc & tools/generate-gemm-test.py --spec test/qs8-qc8w-gemm-minmax-fp32.yaml --output-test test/qs8-qc8w-gemm-minmax-fp32.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-2.cc --output-test test/qs8-qc8w-gemm-minmax-fp32-3.cc --output-bench bench/qs8-qc8w-gemm-fp32.cc & @@ -246,9 +248,6 @@ tools/generate-ibilinear-chw-test.py --spec test/f32-ibilinear-chw.yaml --output ### Tests for RAddExpMinusMax micro-kernels tools/generate-raddexpminusmax-test.py --spec test/f32-raddexpminusmax.yaml --output test/f32-raddexpminusmax.cc & -### Tests for RAddExtExp micro-kernels -tools/generate-raddextexp-test.py --spec test/f32-raddextexp.yaml --output test/f32-raddextexp.cc & - ### Tests for RAddStoreExpMinusMax micro-kernels tools/generate-raddstoreexpminusmax-test.py --spec test/f16-raddstoreexpminusmax.yaml --output test/f16-raddstoreexpminusmax.cc & tools/generate-raddstoreexpminusmax-test.py --spec test/f32-raddstoreexpminusmax.yaml --output test/f32-raddstoreexpminusmax.cc & diff --git a/scripts/generate-x32-packw.sh b/scripts/generate-x32-packw.sh index a9ce67e24366..8eca2c4ce0d6 100755 --- a/scripts/generate-x32-packw.sh +++ b/scripts/generate-x32-packw.sh @@ -141,4 +141,16 @@ tools/xngen src/x32-packw/rvv.c.in -D NR=m8 -D KBLOCK=2 -o src/x32-packw/gen/x32 tools/xngen src/x32-packw/rvv.c.in -D NR=m8 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x8v-gemm-goi-rvv-u4.c & tools/xngen src/x32-packw/rvv.c.in -D NR=m8 -D KBLOCK=8 -o src/x32-packw/gen/x32-packw-x8v-gemm-goi-rvv-u8.c & +################################## ARM NEON ################################### +tools/xngen src/x32-packw/gio-simd.c.in -D BATCH_TILES=4,8,12,16 -D PREFETCH=0 -D KBLOCK=2 -D ARCH=neon -o src/x32-packw/gen/x32-packw-gio-neon-u2.c + +################################# x86 SSE41 ################################# +tools/xngen src/x32-packw/gio-simd.c.in -D BATCH_TILES=4,8,12,16 -D PREFETCH=0 -D KBLOCK=2 -D ARCH=sse41 -o src/x32-packw/gen/x32-packw-gio-sse41-u2.c + +################################## WAsm SIMD ################################## +tools/xngen src/x32-packw/gio-simd.c.in -D BATCH_TILES=4,8,12,16 -D PREFETCH=0 -D KBLOCK=2 -D ARCH=wasmsimd -o src/x32-packw/gen/x32-packw-gio-wasmsimd-u2.c + +################################## Hexagon HVX ################################# +tools/xngen src/x32-packw/gio-simd.c.in -D BATCH_TILES=32,64,96,128 -D PREFETCH=0 -D KBLOCK=2 -D ARCH=hvx -o src/x32-packw/gen/x32-packw-gio-hvx-u2.c + wait diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 58260b2920b3..3b6e35fd745e 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -46,45 +46,67 @@ tools/xngen src/x8-packw/kr-gio-scalar.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D tools/xngen 
src/x8-packw/kr-gio-scalar.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-gio-scalar.c & ### AVXVNNI micro-kernels +### C4 packing for AMX +tools/xngen src/x8-packw/c4-avxvnni.c.in -D NR=64 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/c4-avxvnni.c.in -D NR=64 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni-prfm.c & + ### C8 packing -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c & - -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c & - -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c & - -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=0 
-o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c & + +### GIO packing +tools/xngen 
src/x8-packw/kr-gio-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni.c & +tools/xngen src/x8-packw/kr-gio-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni-prfm.c & # X8 packing -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2-prfm.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx.c & -tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=X8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx-prfm.c & tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT=MADD -D PREFETCH=0 
-o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c & +# QC4W +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=0 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D VARIANT= -D PREFETCH=1 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=0 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D DATATYPE=QS4 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D VARIANT= -D PREFETCH=1 -o src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni-prfm.c & + ### WAsm Relaxed SIMD ### C8 packing tools/xngen src/x8-packw/kr-wasmdot.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c & tools/xngen src/x8-packw/kr-wasmdot.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP=128 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c & + wait diff --git a/src/configs/experiments-config.c b/src/configs/experiments-config.c index 24938ea65df7..3983c63de3bf 100644 --- a/src/configs/experiments-config.c +++ b/src/configs/experiments-config.c @@ -3,7 +3,7 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#include +#include "experiments-config.h" static struct xnn_experiment_config experiment_config = {0}; diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 4d131faafb87..6ee121464b91 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -41,7 +41,13 @@ static struct xnn_gemm_config qd8_f32_qb4w_gemm_config = {0}; static struct xnn_gemm_config qd8_f32_qc4w_gemm_config = {0}; static struct xnn_gemm_config qd8_f32_qc8w_gemm_config = {0}; static struct xnn_gemm_config qp8_f32_qc4w_gemm_config = {0}; +static struct xnn_gemm_config qp8_f32_qc8w_gemm_config = {0}; static struct xnn_gemm_config qp8_f32_qb4w_gemm_config = {0}; +static struct xnn_gemm_config qdu8_f32_qc4w_gemm_config = {0}; +static struct xnn_gemm_config qdu8_f16_qc8w_gemm_config = {0}; +static struct xnn_gemm_config qdu8_f32_qc8w_gemm_config = {0}; +static struct xnn_gemm_config qdu8_f32_qb4w_gemm_config = {0}; +static struct xnn_gemm_config qdu8_f16_qc4w_gemm_config = {0}; static struct xnn_gemm_config qs8_qc8w_gemm_config = {0}; static struct xnn_gemm_config qu8_gemm_config = {0}; @@ -58,7 +64,13 @@ XNN_INIT_ONCE_GUARD(qd8_f32_qb4w_gemm); XNN_INIT_ONCE_GUARD(qd8_f32_qc4w_gemm); XNN_INIT_ONCE_GUARD(qd8_f32_qc8w_gemm); XNN_INIT_ONCE_GUARD(qp8_f32_qc4w_gemm); +XNN_INIT_ONCE_GUARD(qp8_f32_qc8w_gemm); XNN_INIT_ONCE_GUARD(qp8_f32_qb4w_gemm); +XNN_INIT_ONCE_GUARD(qdu8_f32_qc4w_gemm); +XNN_INIT_ONCE_GUARD(qdu8_f16_qc8w_gemm); +XNN_INIT_ONCE_GUARD(qdu8_f32_qc8w_gemm); +XNN_INIT_ONCE_GUARD(qdu8_f32_qb4w_gemm); +XNN_INIT_ONCE_GUARD(qdu8_f16_qc4w_gemm); XNN_INIT_ONCE_GUARD(qs8_qc8w_gemm); XNN_INIT_ONCE_GUARD(qu8_gemm); @@ -66,6 +78,7 @@ static void init_f16_gemm_config(void) { #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon_fp16_arith) { f16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f16_gemm_minmax_ukernel_1x8__neonfp16arith_ld64); f16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f16_gemm_minmax_ukernel_6x8__neonfp16arith_ld64); @@ -80,6 +93,7 @@ static void init_f16_gemm_config(void) { #elif XNN_ARCH_ARM64 && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon_fp16_arith) { #if XNN_ENABLE_ASSEMBLY switch (cpuinfo_get_core(0)->uarch) { @@ -201,6 +215,7 @@ static void init_f16_gemm_config(void) { #elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
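The (void) hardware_config casts added throughout these init functions silence unused-variable warnings on builds where every feature-gated #if XNN_ENABLE_* branch is compiled out, and the new assert(...mr <= XNN_MAX_MR) checks bound the per-MR micro-kernel tables indexed via XNN_MR_TO_INDEX. A standalone sketch of the pattern; the ISA name and flag below are illustrative, not taken from this change.

    #include <assert.h>

    // Relies on XNNPACK's internal hardware-config header for
    // xnn_init_hardware_config() and struct xnn_hardware_config.
    static void init_example_gemm_config(void) {
      const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
      assert(hardware_config != NULL);
      (void) hardware_config;  // May be unused if the block below is compiled out.
    #if XNN_ENABLE_SOME_ISA  // illustrative feature gate, not a real XNNPACK option
      if (hardware_config->use_some_isa) {
        // ... select ISA-specific micro-kernels and set mr/nr ...
      }
    #endif
    }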
#if XNN_ENABLE_AVX512FP16 if (hardware_config->use_x86_avx512fp16) { f16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f16_gemm_minmax_ukernel_1x64__avx512fp16_broadcast); @@ -226,6 +241,7 @@ static void init_f16_gemm_config(void) { f16_gemm_config.nr = 16; } #endif + assert(f16_gemm_config.mr <= XNN_MAX_MR); } #if XNN_ARCH_WASMSIMD @@ -245,13 +261,14 @@ static void init_f16_gemm_config(void) { static void init_pf32_gemm_config(void) { #if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI - const struct xnn_hardware_config* hardware_config = - xnn_init_hardware_config(); + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (XNN_ENABLE_ARM_SME2 && hardware_config->use_arm_sme2) { #if XNN_ENABLE_ARM_SME2 const size_t mr = xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_mr(); const size_t nr = xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2_get_nr(); + pf32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_pf32_gemm_minmax_ukernel_1x32__neonsme2); pf32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(nr)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2); pf32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; pf32_gemm_config.pack_weights_and_biases = xnn_pack_kai_f32_weights_and_biases; @@ -261,6 +278,7 @@ static void init_pf32_gemm_config(void) { pf32_gemm_config.nr = nr; #endif // XNN_ENABLE_ARM_SME2 } + assert(pf32_gemm_config.mr <= XNN_MAX_MR); #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI } @@ -268,6 +286,7 @@ static void init_f32_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { #if XNN_ENABLE_ASSEMBLY switch (cpuinfo_get_uarch(0)->uarch) { @@ -675,6 +694,7 @@ static void init_f32_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. #if XNN_ENABLE_AVX512F if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) { f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast); @@ -739,6 +759,7 @@ static void init_f32_gemm_config(void) { #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->is_x86) { #if XNN_ARCH_WASMRELAXEDSIMD f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); @@ -830,6 +851,7 @@ static void init_f32_gemm_config(void) { #elif XNN_ARCH_WASM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
if (hardware_config->is_x86) { f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_2x4__scalar); @@ -870,6 +892,7 @@ static void init_f32_gemm_config(void) { #elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_riscv_vector) { f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x4v__rvv); f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4v__rvv); @@ -902,12 +925,14 @@ static void init_f32_gemm_config(void) { f32_gemm_config.mr = 4; f32_gemm_config.nr = 4; #endif + assert(f32_gemm_config.mr <= XNN_MAX_MR); } static void init_f32_gemm_nr2_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { f32_gemm_nr2_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x2__neon_lane_ld64); f32_gemm_nr2_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x2__neon_lane_ld64); @@ -967,6 +992,7 @@ static void init_f32_gemm_nr2_config(void) { #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->is_x86) { #if XNN_ARCH_WASMRELAXEDSIMD f32_gemm_nr2_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x2c4__wasmrelaxedsimd_fma); @@ -1027,6 +1053,7 @@ static void init_f32_gemm_nr2_config(void) { f32_gemm_nr2_config.mr = 4; f32_gemm_nr2_config.nr = 2; #endif + assert(f32_gemm_nr2_config.mr <= XNN_MAX_MR); } static void init_f32_qc4w_gemm_config(void) { @@ -1034,6 +1061,7 @@ static void init_f32_qc4w_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { f32_qc4w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc4w_gemm_minmax_ukernel_1x8__neon_lane_ld64); f32_qc4w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc4w_gemm_minmax_ukernel_4x8__neon_lane_ld64); @@ -1060,6 +1088,7 @@ static void init_f32_qc4w_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
#if XNN_ENABLE_AVX512SKX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { f32_qc4w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc4w_gemm_minmax_ukernel_1x32__avx512skx_broadcast); @@ -1121,12 +1150,14 @@ static void init_f32_qc4w_gemm_config(void) { f32_qc4w_gemm_config.mr = 4; f32_qc4w_gemm_config.nr = 4; #endif + assert(f32_qc4w_gemm_config.mr <= XNN_MAX_MR); } static void init_f32_qc8w_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { f32_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc8w_gemm_minmax_ukernel_1x8__neon_lane_ld64); f32_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc8w_gemm_minmax_ukernel_4x8__neon_lane_ld64); @@ -1223,6 +1254,7 @@ static void init_f32_qc8w_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. #if XNN_ENABLE_AVX512SKX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { f32_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc8w_gemm_minmax_ukernel_1x32__avx512skx_broadcast); @@ -1270,6 +1302,7 @@ static void init_f32_qc8w_gemm_config(void) { #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->is_x86) { #if XNN_ARCH_WASMRELAXEDSIMD f32_qc8w_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_qc8w_gemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); @@ -1336,6 +1369,79 @@ static void init_f32_qc8w_gemm_config(void) { f32_qc8w_gemm_config.mr = 4; f32_qc8w_gemm_config.nr = 4; #endif + assert(f32_qc8w_gemm_config.mr <= XNN_MAX_MR); +} + +static void init_qdu8_f16_qc4w_gemm_config(void) { + // Use the same packing function throughout. + qdu8_f16_qc4w_gemm_config.pack_weights_and_biases = + (xnn_pack_weights_and_biases_fn)xnn_pack_qs4_weights_and_biases; + qdu8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = + (xnn_packed_stride_weights_and_biases_fn) + xnn_packed_stride_qs4_weights_and_biases; + qdu8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; + qdu8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4w_gemm_goi_w; + #if (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
+ #if XNN_ENABLE_AVX256VNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256vnni) { + qdu8_f16_qc4w_gemm_config.arch = xnn_arch_x86_avx256vnni; + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni); + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni); + qdu8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; + qdu8_f16_qc4w_gemm_config.mr = 8; + qdu8_f16_qc4w_gemm_config.nr = 8; + qdu8_f16_qc4w_gemm_config.log2_kr = 3; + qdu8_f16_qc4w_gemm_config.planes = 2; + } else + #endif + #if XNN_ENABLE_AVXVNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { + qdu8_f16_qc4w_gemm_config.arch = xnn_arch_x86_avxvnni; + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); + qdu8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; + qdu8_f16_qc4w_gemm_config.mr = 5; + qdu8_f16_qc4w_gemm_config.nr = 8; + qdu8_f16_qc4w_gemm_config.log2_kr = 3; + qdu8_f16_qc4w_gemm_config.planes = 2; + } else + #endif + #if XNN_ENABLE_AVX256SKX + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { + qdu8_f16_qc4w_gemm_config.arch = xnn_arch_x86_avx256skx; + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm); + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm); + qdu8_f16_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; + qdu8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; + qdu8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; + qdu8_f16_qc4w_gemm_config.mr = 8; + qdu8_f16_qc4w_gemm_config.nr = 8; + qdu8_f16_qc4w_gemm_config.log2_kr = 3; + qdu8_f16_qc4w_gemm_config.planes = 2; + } else + #endif + if (hardware_config->use_x86_avx2) { + qdu8_f16_qc4w_gemm_config.arch = xnn_arch_x86_avx2; + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm); + qdu8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm); + qdu8_f16_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
+ qdu8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; + qdu8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; + qdu8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; + qdu8_f16_qc4w_gemm_config.mr = 4; + qdu8_f16_qc4w_gemm_config.nr = 8; + qdu8_f16_qc4w_gemm_config.log2_kr = 3; + qdu8_f16_qc4w_gemm_config.planes = 2; + } + #endif + assert(qdu8_f16_qc4w_gemm_config.mr <= XNN_MAX_MR); + assert(qdu8_f16_qc4w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); } static void init_qd8_f16_qc4w_gemm_config(void) { @@ -1350,6 +1456,7 @@ static void init_qd8_f16_qc4w_gemm_config(void) { #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot && hardware_config->use_arm_neon_fp16_arith) { #if XNN_ENABLE_ARM_DOTPROD @@ -1373,6 +1480,7 @@ static void init_qd8_f16_qc4w_gemm_config(void) { #elif XNN_ARCH_ARM64 && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { #if XNN_ENABLE_ARM_I8MM qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16c8__neoni8mm); @@ -1401,68 +1509,9 @@ static void init_qd8_f16_qc4w_gemm_config(void) { qd8_f16_qc4w_gemm_config.nr = 16; qd8_f16_qc4w_gemm_config.planes = 2; } - #elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - assert(hardware_config != NULL); - #if XNN_ENABLE_AVX256VNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256vnni) { - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni); - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni); - qd8_f16_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; - qd8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__scalar; - qd8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; - qd8_f16_qc4w_gemm_config.mr = 8; - qd8_f16_qc4w_gemm_config.nr = 8; - qd8_f16_qc4w_gemm_config.log2_kr = 3; - qd8_f16_qc4w_gemm_config.planes = 2; - } else - #endif - #if XNN_ENABLE_AVXVNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); - qd8_f16_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; - qd8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__scalar; - qd8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; - qd8_f16_qc4w_gemm_config.mr = 5; - qd8_f16_qc4w_gemm_config.nr = 8; - qd8_f16_qc4w_gemm_config.log2_kr = 3; - qd8_f16_qc4w_gemm_config.planes = 2; - } else - #endif - #if XNN_ENABLE_AVX256SKX - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm); - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm); - qd8_f16_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; - qd8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; - qd8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; - qd8_f16_qc4w_gemm_config.mr = 8; - qd8_f16_qc4w_gemm_config.nr = 8; - qd8_f16_qc4w_gemm_config.log2_kr = 3; - qd8_f16_qc4w_gemm_config.planes = 2; - } else - #endif - if (hardware_config->use_x86_avx2) { - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm); - qd8_f16_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm); - qd8_f16_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f16_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; - qd8_f16_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; - qd8_f16_qc4w_gemm_config.init.f16_qc4w = xnn_init_f16_qc4w_minmax_scalar_params; - qd8_f16_qc4w_gemm_config.mr = 4; - qd8_f16_qc4w_gemm_config.nr = 8; - qd8_f16_qc4w_gemm_config.log2_kr = 3; - qd8_f16_qc4w_gemm_config.planes = 2; - } #endif + assert(qd8_f16_qc4w_gemm_config.mr <= XNN_MAX_MR); + assert(qd8_f16_qc4w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); } static void init_qd8_f16_qb4w_gemm_config(void) { @@ -1471,6 +1520,7 @@ static void init_qd8_f16_qb4w_gemm_config(void) { #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot && hardware_config->use_arm_neon_fp16_arith) { #if XNN_ENABLE_ARM_DOTPROD @@ -1494,6 +1544,7 @@ static void init_qd8_f16_qb4w_gemm_config(void) { #elif XNN_ARCH_ARM64 && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { #if XNN_ENABLE_ARM_I8MM qd8_f16_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16c8__neoni8mm); @@ -1525,6 +1576,7 @@ static void init_qd8_f16_qb4w_gemm_config(void) { #elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_x86_avx2) { qd8_f16_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8c8__avx2); qd8_f16_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(3)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c8__avx2); @@ -1535,6 +1587,8 @@ static void init_qd8_f16_qb4w_gemm_config(void) { qd8_f16_qb4w_gemm_config.planes = 2; } #endif + assert(qd8_f16_qb4w_gemm_config.mr <= XNN_MAX_MR); + assert(qd8_f16_qb4w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); } static void init_qd8_f32_qc4w_gemm_config(void) { @@ -1546,6 +1600,7 @@ static void init_qd8_f32_qc4w_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { #if XNN_ENABLE_ARM_DOTPROD @@ -1576,6 +1631,7 @@ static void init_qd8_f32_qc4w_gemm_config(void) { #elif XNN_ARCH_ARM64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { #if XNN_ENABLE_ARM_I8MM qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__neoni8mm); @@ -1605,10 +1661,12 @@ static void init_qd8_f32_qc4w_gemm_config(void) { qd8_f32_qc4w_gemm_config.planes = 2; } #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - assert(hardware_config != NULL); #if XNN_ENABLE_AVX512AMX + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512amx) { + qd8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avx512amx; qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x64c4__avx512amx); qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx); qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; @@ -1618,108 +1676,7 @@ static void init_qd8_f32_qc4w_gemm_config(void) { qd8_f32_qc4w_gemm_config.planes = 2; } else #endif // XNN_ENABLE_AVX512AMX - #if XNN_ENABLE_AVX512VNNIGFNI - // Zen4 has gfni but is slower and 8x16 works better on zen4. 14x16 is faster on Sapphire Rapids - // TODO(b/361288044): Re-enable once fixed. - if (false && !XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnnigfni && cpuinfo_get_core(0)->uarch != cpuinfo_uarch_zen4) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(14)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__scalar; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 14; - qd8_f32_qc4w_gemm_config.nr = 16; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else - #endif // XNN_ENABLE_AVX512VNNIGFNI - #if XNN_ENABLE_AVX512VNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__scalar; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 8; - qd8_f32_qc4w_gemm_config.nr = 16; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else - #endif - #if XNN_ENABLE_AVXVNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__scalar; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 5; - qd8_f32_qc4w_gemm_config.nr = 8; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else - #endif - #if XNN_ENABLE_AVX512SKX - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 8; - qd8_f32_qc4w_gemm_config.nr = 16; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else - #endif - #if XNN_ENABLE_AVX256SKX - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 8; - qd8_f32_qc4w_gemm_config.nr = 8; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else - #endif - if (hardware_config->use_x86_avx2) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 4; - qd8_f32_qc4w_gemm_config.nr = 8; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else if (hardware_config->use_x86_ssse3) { - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd); - qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd); - qd8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; - qd8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; - qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; - qd8_f32_qc4w_gemm_config.mr = 4; - qd8_f32_qc4w_gemm_config.nr = 4; - qd8_f32_qc4w_gemm_config.log2_kr = 3; - qd8_f32_qc4w_gemm_config.planes = 2; - } else { + { qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse2_ld128); qd8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse2_ld128); qd8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; @@ -1751,6 +1708,8 @@ static void init_qd8_f32_qc4w_gemm_config(void) { qd8_f32_qc4w_gemm_config.nr = 4; qd8_f32_qc4w_gemm_config.planes = 2; #endif + assert(qd8_f32_qc4w_gemm_config.mr <= XNN_MAX_MR); + assert(qd8_f32_qc4w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); } static void init_qp8_f32_qc4w_gemm_config(void) { @@ -1786,6 +1745,41 @@ static void init_qp8_f32_qc4w_gemm_config(void) { qp8_f32_qc4w_gemm_config.mr_packed = 1; #endif // XNN_ENABLE_ARM_DOTPROD } + assert(qp8_f32_qc4w_gemm_config.mr <= XNN_MAX_MR); +#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI +} + +static void init_qp8_f32_qc8w_gemm_config(void) { +#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI + const struct xnn_hardware_config* hardware_config = + xnn_init_hardware_config(); + assert(hardware_config != NULL); + if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { +#if XNN_ENABLE_ARM_I8MM + qp8_f32_qc8w_gemm_config.minmax.qp8gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_ukernel(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot); + qp8_f32_qc8w_gemm_config.minmax.qp8gemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_qp8gemm_ukernel(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4); + qp8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + qp8_f32_qc8w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qs8_weights_and_biases; + qp8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qs8_weights_and_biases; + qp8_f32_qc8w_gemm_config.mr = 16; + qp8_f32_qc8w_gemm_config.nr = 4; + qp8_f32_qc8w_gemm_config.log2_kr = 3; + qp8_f32_qc8w_gemm_config.mr_packed = 4; +#endif // XNN_ENABLE_ARM_I8MM + } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { +#if XNN_ENABLE_ARM_DOTPROD + qp8_f32_qc8w_gemm_config.minmax.qp8gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_qp8gemm_ukernel(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot); + qp8_f32_qc8w_gemm_config.minmax.qp8gemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_qp8gemm_ukernel(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4); + qp8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + qp8_f32_qc8w_gemm_config.pack_weights_and_biases = xnn_pack_kai_qs8_weights_and_biases; + qp8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_qs8_weights_and_biases; + qp8_f32_qc8w_gemm_config.mr = 16; + qp8_f32_qc8w_gemm_config.nr = 4; + qp8_f32_qc8w_gemm_config.log2_kr = 2; + qp8_f32_qc8w_gemm_config.mr_packed = 4; +#endif // XNN_ENABLE_ARM_DOTPROD + } + assert(qp8_f32_qc8w_gemm_config.mr <= XNN_MAX_MR); #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI } @@ -1822,15 +1816,56 @@ 
static void init_qp8_f32_qb4w_gemm_config(void) { qp8_f32_qb4w_gemm_config.mr_packed = 1; #endif // XNN_ENABLE_ARM_DOTPROD } + assert(qp8_f32_qb4w_gemm_config.mr <= XNN_MAX_MR); #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI } +static void init_qdu8_f32_qb4w_gemm_config(void) { + qdu8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; + #if XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. + #if XNN_ENABLE_AVX512VNNIGFNI + // Zen4 has gfni but is slower and 8x16 works better on zen4. 14x16 is faster on Sapphire Rapids + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnnigfni && cpuinfo_get_core(0)->uarch != cpuinfo_uarch_zen4) { + qdu8_f32_qb4w_gemm_config.arch = xnn_arch_x86_avx512vnnigfni; + qdu8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm); + qdu8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(14)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm); + qdu8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; + qdu8_f32_qb4w_gemm_config.mr = 14; + qdu8_f32_qb4w_gemm_config.nr = 16; + qdu8_f32_qb4w_gemm_config.log2_kr = 3; + qdu8_f32_qb4w_gemm_config.planes = 2; + } else + #endif // XNN_ENABLE_AVX512VNNIGFNI + #if XNN_ENABLE_AVX512VNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { + qdu8_f32_qb4w_gemm_config.arch = xnn_arch_x86_avx512vnni; + qdu8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); + qdu8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm); + qdu8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; + qdu8_f32_qb4w_gemm_config.mr = 8; + qdu8_f32_qb4w_gemm_config.nr = 16; + qdu8_f32_qb4w_gemm_config.log2_kr = 3; + qdu8_f32_qb4w_gemm_config.planes = 2; + } + #else + { + ; + } + #endif + assert(qdu8_f32_qb4w_gemm_config.mr <= XNN_MAX_MR); + #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +} + static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.pack_gemm_goi_bl = (xnn_packw_gemm_goi_bl_ukernel_fn) xnn_pack_qs8_qb4w_gemm_goi_w; #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { #if XNN_ENABLE_ARM_DOTPROD @@ -1861,6 +1896,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { #elif XNN_ARCH_ARM64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
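Each branch in these initializers is gated twice: a compile-time XNN_ENABLE_* macro decides whether the microkernel object code exists in the build at all, and a runtime hardware_config bit decides whether the current CPU may use it; the NEON I8MM check that follows is one instance. A minimal sketch of that shape, assuming the XNNPACK hardware-config declarations are in scope (the initializer name and comments are illustrative, not part of this change):

#include <assert.h>

static void init_example_gemm_config(void) {  // hypothetical initializer
  const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
  assert(hardware_config != NULL);
  if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) {
#if XNN_ENABLE_ARM_I8MM
    // Widest path: I8MM microkernels, referenced only when compiled in.
#endif  // XNN_ENABLE_ARM_I8MM
  } else if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) {
#if XNN_ENABLE_ARM_DOTPROD
    // Narrower path: dot-product microkernels.
#endif  // XNN_ENABLE_ARM_DOTPROD
  }
  // If neither branch fires the config stays empty, which is why the qp8
  // getters later in this file only hand the config out when a kernel
  // pointer was actually set.
}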
if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { #if XNN_ENABLE_ARM_I8MM qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__neoni8mm); @@ -1892,30 +1928,9 @@ static void init_qd8_f32_qb4w_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); - #if XNN_ENABLE_AVX512VNNIGFNI - // Zen4 has gfni but is slower and 8x16 works better on zen4. 14x16 is faster on Sapphire Rapids - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnnigfni && cpuinfo_get_core(0)->uarch != cpuinfo_uarch_zen4) { - qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm); - qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(14)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm); - qd8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; - qd8_f32_qb4w_gemm_config.mr = 14; - qd8_f32_qb4w_gemm_config.nr = 16; - qd8_f32_qb4w_gemm_config.log2_kr = 3; - qd8_f32_qb4w_gemm_config.planes = 2; - } else - #endif // XNN_ENABLE_AVX512VNNIGFNI - #if XNN_ENABLE_AVX512VNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { - qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); - qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm); - qd8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; - qd8_f32_qb4w_gemm_config.mr = 8; - qd8_f32_qb4w_gemm_config.nr = 16; - qd8_f32_qb4w_gemm_config.log2_kr = 3; - qd8_f32_qb4w_gemm_config.planes = 2; - } else - #endif + (void) hardware_config; // May be unused. 
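The Zen 4 carve-out that moves into the new qdu8 configs (the cpuinfo_uarch_zen4 test next to the GFNI branch) selects below the feature-bit level: Zen 4 reports GFNI, but per the comment the 14x16 GFNI tile only pays off on Sapphire Rapids, so the branch also inspects the microarchitecture. A standalone sketch of that test, assuming cpuinfo has already been initialized (the helper name is hypothetical):

#include <stdbool.h>
#include <cpuinfo.h>

// Hypothetical helper mirroring the uarch test used in the qdu8 configs:
// skip the GFNI kernels on Zen 4 even though the CPUID feature bit is set.
static bool prefer_gfni_gemm_kernels(void) {
  const struct cpuinfo_core* core = cpuinfo_get_core(0);
  return core != NULL && core->uarch != cpuinfo_uarch_zen4;
}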
if (hardware_config->use_x86_avx2) { + qd8_f32_qb4w_gemm_config.arch = xnn_arch_x86_avx2; qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8c8__avx2); qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(3)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x8c8__avx2); qd8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; @@ -1924,6 +1939,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.log2_kr = 3; qd8_f32_qb4w_gemm_config.planes = 2; } else if (hardware_config->use_x86_avx) { + qd8_f32_qb4w_gemm_config.arch = xnn_arch_x86_avx; qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__avx_ld128); qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__avx_ld128); qd8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; @@ -1932,6 +1948,7 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.log2_kr = 3; qd8_f32_qb4w_gemm_config.planes = 1; } else if (hardware_config->use_x86_sse4_1) { + qd8_f32_qb4w_gemm_config.arch = xnn_arch_x86_sse4_1; qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128); qd8_f32_qb4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(3)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128); qd8_f32_qb4w_gemm_config.init.f32_qb4w = xnn_init_f32_qb4w_minmax_scalar_params; @@ -1956,6 +1973,8 @@ static void init_qd8_f32_qb4w_gemm_config(void) { qd8_f32_qb4w_gemm_config.nr = 4; qd8_f32_qb4w_gemm_config.planes = 2; #endif + assert(qd8_f32_qb4w_gemm_config.mr <= XNN_MAX_MR); + assert(qd8_f32_qb4w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); } static void init_qd8_f16_qc8w_gemm_config(void) { @@ -1967,6 +1986,7 @@ static void init_qd8_f16_qc8w_gemm_config(void) { #if XNN_ARCH_ARM && XNN_ENABLE_ARM_FP16_VECTOR && XNN_ENABLE_ARM_FP16_SCALAR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { #if XNN_ENABLE_ASSEMBLY if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot && hardware_config->use_arm_neon_fp16_arith) { @@ -2049,6 +2069,7 @@ static void init_qd8_f16_qc8w_gemm_config(void) { #elif XNN_ARCH_ARM64 && XNN_ENABLE_ARM_FP16_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. #if XNN_PLATFORM_IOS || XNN_PLATFORM_MAC || XNN_PLATFORM_WINDOWS #if XNN_ENABLE_ASSEMBLY if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { @@ -2235,6 +2256,7 @@ static void init_qd8_f16_qc8w_gemm_config(void) { #elif (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
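The (void) hardware_config; statements added throughout address a build-configuration corner case: every later use of the variable sits behind an XNN_ENABLE_* guard (and assert() vanishes under NDEBUG), so in a build with those features disabled the local would otherwise draw an unused-variable warning. A reduced sketch of the case being silenced (the function name is illustrative):

#include <assert.h>

static void init_config_in_minimal_build(void) {  // hypothetical
  const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
  assert(hardware_config != NULL);   // compiles to nothing with NDEBUG
  (void) hardware_config;            // May be unused.
#if XNN_ENABLE_AVX512VNNI
  // Sole remaining reference; absent entirely when the flag is off.
  if (hardware_config->use_x86_avx512vnni) { /* pick AVX512-VNNI kernels */ }
#endif  // XNN_ENABLE_AVX512VNNI
}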
#if XNN_ENABLE_AVX512AMX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512amx) { qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx); @@ -2245,47 +2267,15 @@ static void init_qd8_f16_qc8w_gemm_config(void) { qd8_f16_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. qd8_f16_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. qd8_f16_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qd8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar; + qd8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni_prfm; qd8_f16_qc8w_gemm_config.mr = 16; qd8_f16_qc8w_gemm_config.nr = 64; qd8_f16_qc8w_gemm_config.log2_kr = 2; } else #endif - #if XNN_ENABLE_AVX256VNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256vnni) { - qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni); - qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni); - qd8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni); - qd8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni); - qd8_f16_qc8w_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params; - qd8_f16_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qd8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm; - qd8_f16_qc8w_gemm_config.mr = 8; - qd8_f16_qc8w_gemm_config.nr = 8; - qd8_f16_qc8w_gemm_config.log2_kr = 3; - } else - #endif - #if XNN_ENABLE_AVXVNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { - // AVX VNNI checked before AVX512SKX as it performs better with VNNI microkernels - qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); - qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); - qd8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm); - qd8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm); - qd8_f16_qc8w_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params; - qd8_f16_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f16_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f16_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qd8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm; - qd8_f16_qc8w_gemm_config.mr = 5; - qd8_f16_qc8w_gemm_config.nr = 8; - qd8_f16_qc8w_gemm_config.log2_kr = 3; - } else - #endif #if XNN_ENABLE_AVX256SKX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { + qd8_f16_qc8w_gemm_config.arch = xnn_arch_x86_avx256skx; qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx); qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256skx); qd8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256skx); @@ -2297,6 +2287,7 @@ static void init_qd8_f16_qc8w_gemm_config(void) { } else #endif if (hardware_config->use_x86_avx2) { + qd8_f16_qc8w_gemm_config.arch = xnn_arch_x86_avx2; qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx2); qd8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(3)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avx2); qd8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx2); @@ -2307,6 +2298,237 @@ static void init_qd8_f16_qc8w_gemm_config(void) { qd8_f16_qc8w_gemm_config.log2_kr = 3; } #endif + assert(qd8_f16_qc8w_gemm_config.mr <= XNN_MAX_MR); + assert(qd8_f16_qc8w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); +} + +static void init_qdu8_f16_qc8w_gemm_config(void) { + // Use the same packing function throughout. + qdu8_f16_qc8w_gemm_config.pack_weights_and_biases = + (xnn_pack_weights_and_biases_fn)xnn_pack_qs8_weights_and_biases; + qdu8_f16_qc8w_gemm_config.packed_stride_weights_and_biases = + (xnn_packed_stride_weights_and_biases_fn) + xnn_packed_stride_qs8_weights_and_biases; + qdu8_f16_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + qdu8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_gemm_goi_w; + #if XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
+ #if XNN_ENABLE_AVX256VNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256vnni) { + qdu8_f16_qc8w_gemm_config.arch = xnn_arch_x86_avx256vnni; + qdu8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni); + qdu8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni); + qdu8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni); + qdu8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni); + qdu8_f16_qc8w_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params; + qdu8_f16_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + qdu8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm; + qdu8_f16_qc8w_gemm_config.mr = 8; + qdu8_f16_qc8w_gemm_config.nr = 8; + qdu8_f16_qc8w_gemm_config.log2_kr = 3; + } else + #endif + #if XNN_ENABLE_AVXVNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { + // AVX VNNI checked before AVX512SKX as it performs better with VNNI microkernels + qdu8_f16_qc8w_gemm_config.arch = xnn_arch_x86_avxvnni; + qdu8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); + qdu8_f16_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); + qdu8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm); + qdu8_f16_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm); + qdu8_f16_qc8w_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params; + qdu8_f16_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f16_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + qdu8_f16_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm; + qdu8_f16_qc8w_gemm_config.mr = 5; + qdu8_f16_qc8w_gemm_config.nr = 8; + qdu8_f16_qc8w_gemm_config.log2_kr = 3; + } + #else + { + ; + } + #endif + assert(qdu8_f16_qc8w_gemm_config.mr <= XNN_MAX_MR); + assert(qdu8_f16_qc8w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); + #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +} + +static void init_qdu8_f32_qc8w_gemm_config(void) { + // Use the same packing function throughout. 
+ qdu8_f32_qc8w_gemm_config.pack_weights_and_biases = + (xnn_pack_weights_and_biases_fn)xnn_pack_qs8_weights_and_biases; + qdu8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = + (xnn_packed_stride_weights_and_biases_fn) + xnn_packed_stride_qs8_weights_and_biases; + qdu8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_gemm_goi_w; + #if XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. + #if XNN_ENABLE_AVX512VNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { + qdu8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avx512vnni; + qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); + qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm); + qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm); + qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm); + qdu8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + qdu8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + #if XNN_ENABLE_AVX256VNNI + qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm; + #else + qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar; + #endif + qdu8_f32_qc8w_gemm_config.mr = 10; + qdu8_f32_qc8w_gemm_config.nr = 16; + qdu8_f32_qc8w_gemm_config.log2_kr = 3; + } else + #endif + #if XNN_ENABLE_AVXVNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { + qdu8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avxvnni; + qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); + qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); + qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm); + qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm); + qdu8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + qdu8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
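The pairing above is the packing convention the new qdu8 configs appear to follow: a generic packer (xnn_pack_qs8_weights_and_biases and its packed-stride helper) is installed as the default at the top of the function, and a branch that prefers a kernel-specific pack_gemm_goi/pack_gemm_gio microkernel opts out by setting the generic hooks back to NULL. A small hypothetical helper expressing that convention (the function is illustrative, not part of XNNPACK; it assumes the xnn_gemm_config definition is in scope):

#include <stdbool.h>

// Hypothetical: a non-NULL pack_weights_and_biases means the generic packer
// is used; NULL means packing goes through the pack_gemm_goi/gio ukernels.
static bool uses_generic_packer(const struct xnn_gemm_config* gemm_config) {
  return gemm_config->pack_weights_and_biases != NULL;
}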
+ qdu8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm; + qdu8_f32_qc8w_gemm_config.mr = 5; + qdu8_f32_qc8w_gemm_config.nr = 8; + qdu8_f32_qc8w_gemm_config.log2_kr = 3; + } + #else + { + ; + } + #endif + assert(qdu8_f32_qc8w_gemm_config.mr <= XNN_MAX_MR); + assert(qdu8_f32_qc8w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); + #endif //XNN_ARCH_X86 || XNN_ARCH_X86_64 +} + +static void init_qdu8_f32_qc4w_gemm_config(void) { + // Use the same packing function throughout. + qdu8_f32_qc4w_gemm_config.pack_weights_and_biases = (xnn_pack_weights_and_biases_fn) xnn_pack_qs4_weights_and_biases; + qdu8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = (xnn_packed_stride_weights_and_biases_fn) xnn_packed_stride_qs4_weights_and_biases; + qdu8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4w_gemm_gio_w; // Ignored + qdu8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4w_gemm_goi_w; // Ignored + #if XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. + #if XNN_ENABLE_AVX512VNNIGFNI && XNN_ENABLE_AVX256VNNI + // Zen4 has gfni but is slower and 8x16 works better on zen4. 14x16 is faster on Sapphire Rapids + // TODO(b/361288044): Re-enable once fixed. + if (false && !XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnnigfni && cpuinfo_get_core(0)->uarch != cpuinfo_uarch_zen4) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avx512vnnigfni; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm); + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(14)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm); + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 14; + qdu8_f32_qc4w_gemm_config.nr = 16; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } else + #endif // XNN_ENABLE_AVX512VNNIGFNI + #if XNN_ENABLE_AVX512VNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avx512vnni; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm); + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 8; + qdu8_f32_qc4w_gemm_config.nr = 16; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } else + #endif + #if XNN_ENABLE_AVXVNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avxvnni; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); + 
qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 5; + qdu8_f32_qc4w_gemm_config.nr = 8; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } else + #endif + #if XNN_ENABLE_AVX512SKX + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avx512skx; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm); + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm); + qdu8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; + qdu8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 8; + qdu8_f32_qc4w_gemm_config.nr = 16; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } else + #endif + #if XNN_ENABLE_AVX256SKX + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avx256skx; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm); + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm); + qdu8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; + qdu8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 8; + qdu8_f32_qc4w_gemm_config.nr = 8; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } else + #endif + if (hardware_config->use_x86_avx2) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_avx2; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm); + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm); + qdu8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
+ qdu8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; + qdu8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 4; + qdu8_f32_qc4w_gemm_config.nr = 8; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } else if (hardware_config->use_x86_ssse3) { + qdu8_f32_qc4w_gemm_config.arch = xnn_arch_x86_ssse3; + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd); + qdu8_f32_qc4w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd); + qdu8_f32_qc4w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qdu8_f32_qc4w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_gio_w; + qdu8_f32_qc4w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_qc4uw_gemm_goi_w; + qdu8_f32_qc4w_gemm_config.init.f32_qc4w = xnn_init_f32_qc4w_minmax_scalar_params; + qdu8_f32_qc4w_gemm_config.mr = 4; + qdu8_f32_qc4w_gemm_config.nr = 4; + qdu8_f32_qc4w_gemm_config.log2_kr = 3; + qdu8_f32_qc4w_gemm_config.planes = 2; + } + assert(qdu8_f32_qc4w_gemm_config.mr <= XNN_MAX_MR); + assert(qdu8_f32_qc4w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); + #endif //XNN_ARCH_X86 || XNN_ARCH_X86_64 } static void init_qd8_f32_qc8w_gemm_config(void) { @@ -2321,6 +2543,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { #if XNN_ENABLE_ASSEMBLY #if XNN_ENABLE_ARM_DOTPROD @@ -2410,6 +2633,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { #elif XNN_ARCH_ARM64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. #if XNN_PLATFORM_IOS || XNN_PLATFORM_MAC || XNN_PLATFORM_WINDOWS #if XNN_ENABLE_ASSEMBLY if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { @@ -2598,8 +2822,10 @@ static void init_qd8_f32_qc8w_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
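None of these init_* functions is called directly; each runs once behind the corresponding xnn_init_*_gemm_config() getter, which this change simplifies and extends with getters for the new qdu8 configs. A hypothetical call-site sketch, assuming the XNNPACK config headers are in scope:

// Hypothetical caller: the getter lazily initializes the config and returns
// NULL when the hardware query fails (and, for the qp8 getters, when no
// kernel was selected), so the result must be checked before use.
static void select_qdu8_f32_qc8w_path(void) {
  const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_gemm_config();
  if (gemm_config == NULL) {
    return;  // No usable config on this build/CPU; fall back to another path.
  }
  // gemm_config->mr and gemm_config->nr describe the selected tile shape.
}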
#if XNN_ENABLE_AVX512AMX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512amx) { + qd8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avx512amx; qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx); qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx); qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx); @@ -2608,51 +2834,15 @@ static void init_qd8_f32_qc8w_gemm_config(void) { qd8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. qd8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. qd8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qd8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar; + qd8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni_prfm; qd8_f32_qc8w_gemm_config.mr = 16; qd8_f32_qc8w_gemm_config.nr = 64; qd8_f32_qc8w_gemm_config.log2_kr = 2; } else #endif - #if XNN_ENABLE_AVX512VNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { - qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm); - qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm); - qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm); - qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm); - qd8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - qd8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qd8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - #if XNN_ENABLE_AVX256VNNI - qd8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm; - #else - qd8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar; - #endif - qd8_f32_qc8w_gemm_config.mr = 10; - qd8_f32_qc8w_gemm_config.nr = 16; - qd8_f32_qc8w_gemm_config.log2_kr = 3; - } else - #endif - #if XNN_ENABLE_AVXVNNI - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { - // AVX VNNI should be checked before AVX512SKX as it performs better with VNNI microkernels - qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm); - qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm); - qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm); - qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm); - qd8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - qd8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. - qd8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qd8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm; - qd8_f32_qc8w_gemm_config.mr = 5; - qd8_f32_qc8w_gemm_config.nr = 8; - qd8_f32_qc8w_gemm_config.log2_kr = 3; - } else - #endif #if XNN_ENABLE_AVX512SKX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { + qd8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avx512skx; qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512skx_prfm); qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512skx_prfm); qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512skx_prfm); @@ -2665,6 +2855,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { #endif #if XNN_ENABLE_AVX256SKX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { + qd8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avx256skx; qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx); qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(8)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256skx); qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256skx); @@ -2676,6 +2867,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { } else 
#endif if (hardware_config->use_x86_avx2) { + qd8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avx2; qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx2); qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avx2); qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx2); @@ -2685,6 +2877,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { qd8_f32_qc8w_gemm_config.nr = 8; qd8_f32_qc8w_gemm_config.log2_kr = 3; } else if (hardware_config->use_x86_sse4_1) { + qd8_f32_qc8w_gemm_config.arch = xnn_arch_x86_sse4_1; qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__sse41_ld64); qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__sse41_ld64); qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__sse41_ld64); @@ -2706,6 +2899,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { #elif XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_wasm_sdot) { if (hardware_config->is_x86) { qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__wasmsdot); @@ -2759,6 +2953,7 @@ static void init_qd8_f32_qc8w_gemm_config(void) { #elif XNN_ARCH_WASM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->is_x86) { qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x2__scalar); qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x2__scalar); @@ -2785,6 +2980,8 @@ static void init_qd8_f32_qc8w_gemm_config(void) { qd8_f32_qc8w_gemm_config.mr = 4; qd8_f32_qc8w_gemm_config.nr = 4; #endif + assert(qd8_f32_qc8w_gemm_config.mr <= XNN_MAX_MR); + assert(qd8_f32_qc8w_gemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1)); } static void init_qs8_qc8w_gemm_config(void) { @@ -2799,6 +2996,7 @@ static void init_qs8_qc8w_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { #if XNN_ENABLE_ASSEMBLY if (XNN_ENABLE_ARM_DOTPROD && hardware_config->use_arm_neon_dot) { @@ -3050,6 +3248,7 @@ static void init_qs8_qc8w_gemm_config(void) { #elif XNN_ARCH_ARM64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
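The .arch = xnn_arch_x86_* assignments added to each branch record which ISA the chosen kernels assume. This hunk does not show the consumer, so the following is only a plausible sketch of how such a tag could be cross-checked against the arch_flags bitmask that hardware-config.c populates (hypothetical helper, not part of this change):

#include <stdbool.h>

// Hypothetical sanity check: the ISA recorded for the selected GEMM kernels
// should be contained in the feature flags detected for the running CPU.
static bool gemm_arch_is_supported(const struct xnn_gemm_config* gemm_config,
                                   const struct xnn_hardware_config* hardware_config) {
  return (gemm_config->arch & hardware_config->arch_flags) == gemm_config->arch;
}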
#if XNN_PLATFORM_IOS || XNN_PLATFORM_MAC || XNN_PLATFORM_WINDOWS #if XNN_ENABLE_ASSEMBLY if (XNN_ENABLE_ARM_I8MM && hardware_config->use_arm_neon_i8mm) { @@ -3324,6 +3523,7 @@ static void init_qs8_qc8w_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. #if XNN_ENABLE_AVX512AMX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512amx) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx); @@ -3334,7 +3534,7 @@ static void init_qs8_qc8w_gemm_config(void) { qs8_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. qs8_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. qs8_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar; + qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni_prfm; qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w; qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w; qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w; @@ -3487,6 +3687,7 @@ static void init_qs8_qc8w_gemm_config(void) { #if XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_wasm_sdot) { if (hardware_config->is_x86) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__wasmsdot); @@ -3560,6 +3761,7 @@ static void init_qs8_qc8w_gemm_config(void) { #elif XNN_ARCH_WASM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->is_x86) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x2__scalar_imagic); qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x2__scalar_imagic); @@ -3595,6 +3797,7 @@ static void init_qs8_qc8w_gemm_config(void) { qs8_qc8w_gemm_config.mr = 3; qs8_qc8w_gemm_config.nr = 4; #endif + assert(qs8_qc8w_gemm_config.mr <= XNN_MAX_MR); } static void init_qu8_gemm_config(void) { @@ -3609,6 +3812,7 @@ static void init_qu8_gemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->use_arm_neon) { #if XNN_ENABLE_ASSEMBLY switch (cpuinfo_get_uarch(0)->uarch) { @@ -3818,6 +4022,7 @@ static void init_qu8_gemm_config(void) { #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. 
#if XNN_ENABLE_AVX512SKX if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { qu8_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm); @@ -3880,6 +4085,7 @@ static void init_qu8_gemm_config(void) { #elif XNN_ARCH_WASM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); + (void) hardware_config; // May be unused. if (hardware_config->is_x86) { qu8_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qu8_gemm_minmax_fp32_ukernel_1x2__scalar_imagic); qu8_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qu8_gemm_minmax_fp32_ukernel_2x2__scalar_imagic); @@ -3906,6 +4112,7 @@ static void init_qu8_gemm_config(void) { qu8_gemm_config.mr = 3; qu8_gemm_config.nr = 4; #endif + assert(qu8_gemm_config.mr <= XNN_MAX_MR); } const struct xnn_gemm_config* xnn_init_f16_gemm_config() { @@ -3918,8 +4125,7 @@ const struct xnn_gemm_config* xnn_init_f16_gemm_config() { } const struct xnn_gemm_config* xnn_init_pf32_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(pf32_gemm); @@ -3927,8 +4133,7 @@ const struct xnn_gemm_config* xnn_init_pf32_gemm_config() { } const struct xnn_gemm_config* xnn_init_f32_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(f32_gemm); @@ -3936,8 +4141,7 @@ const struct xnn_gemm_config* xnn_init_f32_gemm_config() { } const struct xnn_gemm_config* xnn_init_f32_gemm_nr2_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(f32_gemm_nr2); @@ -3945,8 +4149,7 @@ const struct xnn_gemm_config* xnn_init_f32_gemm_nr2_config() { } const struct xnn_gemm_config* xnn_init_f32_qc4w_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(f32_qc4w_gemm); @@ -3954,8 +4157,7 @@ const struct xnn_gemm_config* xnn_init_f32_qc4w_gemm_config() { } const struct xnn_gemm_config* xnn_init_f32_qc8w_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(f32_qc8w_gemm); @@ -3976,10 +4178,24 @@ const struct xnn_gemm_config* xnn_init_qd8_f16_qc4w_gemm_config() { if (hardware_config == NULL || !xnn_is_f16_compatible_config(hardware_config)) { return NULL; } +#if (XNN_ARCH_X86 || XNN_ARCH_X86_64) && !XNN_PLATFORM_MOBILE + // there are no kernels on x86. qdu8_f16_qc4w kernels are used instead. 
+ return NULL; +#endif + XNN_INIT_ONCE(qd8_f16_qc4w_gemm); return &qd8_f16_qc4w_gemm_config; } +const struct xnn_gemm_config* xnn_init_qdu8_f16_qc4w_gemm_config() { + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + if (hardware_config == NULL || !xnn_is_f16_compatible_config(hardware_config)) { + return NULL; + } + XNN_INIT_ONCE(qdu8_f16_qc4w_gemm); + return &qdu8_f16_qc4w_gemm_config; +} + const struct xnn_gemm_config* xnn_init_qd8_f16_qb4w_gemm_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL || !xnn_is_f16_compatible_config(hardware_config)) { @@ -3990,26 +4206,55 @@ const struct xnn_gemm_config* xnn_init_qd8_f16_qb4w_gemm_config() { } const struct xnn_gemm_config* xnn_init_qd8_f32_qc4w_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(qd8_f32_qc4w_gemm); return &qd8_f32_qc4w_gemm_config; } +const struct xnn_gemm_config* xnn_init_qdu8_f32_qc4w_gemm_config() { + if (xnn_init_hardware_config() == NULL) { + return NULL; + } + XNN_INIT_ONCE(qdu8_f32_qc4w_gemm); + return &qdu8_f32_qc4w_gemm_config; +} + const struct xnn_gemm_config* xnn_init_qd8_f32_qb4w_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(qd8_f32_qb4w_gemm); return &qd8_f32_qb4w_gemm_config; } +const struct xnn_gemm_config* xnn_init_qdu8_f32_qb4w_gemm_config() { + if (xnn_init_hardware_config() == NULL) { + return NULL; + } + XNN_INIT_ONCE(qdu8_f32_qb4w_gemm); + return &qdu8_f32_qb4w_gemm_config; +} + +const struct xnn_gemm_config* xnn_init_qdu8_f16_qc8w_gemm_config() { + if (xnn_init_hardware_config() == NULL) { + return NULL; + } + XNN_INIT_ONCE(qdu8_f16_qc8w_gemm); + return &qdu8_f16_qc8w_gemm_config; +} + +const struct xnn_gemm_config* xnn_init_qdu8_f32_qc8w_gemm_config() { + if (xnn_init_hardware_config() == NULL) { + return NULL; + } + XNN_INIT_ONCE(qdu8_f32_qc8w_gemm); + return &qdu8_f32_qc8w_gemm_config; +} + const struct xnn_gemm_config* xnn_init_qd8_f32_qc8w_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(qd8_f32_qc8w_gemm); @@ -4017,12 +4262,10 @@ const struct xnn_gemm_config* xnn_init_qd8_f32_qc8w_gemm_config() { } const struct xnn_gemm_config* xnn_init_qp8_f32_qc4w_gemm_config() { - const struct xnn_hardware_config* hardware_config = - xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } -XNN_INIT_ONCE(qp8_f32_qc4w_gemm); + XNN_INIT_ONCE(qp8_f32_qc4w_gemm); // Only return the config pointer if it actually provides a kernel. if (qp8_f32_qc4w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) { return &qp8_f32_qc4w_gemm_config; @@ -4030,6 +4273,18 @@ XNN_INIT_ONCE(qp8_f32_qc4w_gemm); return NULL; } +const struct xnn_gemm_config* xnn_init_qp8_f32_qc8w_gemm_config() { + if (xnn_init_hardware_config() == NULL) { + return NULL; + } + XNN_INIT_ONCE(qp8_f32_qc8w_gemm); + // Only return the config pointer if it actually provides a kernel. 
+ if (qp8_f32_qc8w_gemm_config.minmax.qp8gemm[0].function[0] != NULL) { + return &qp8_f32_qc8w_gemm_config; + } + return NULL; +} + const struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -4045,8 +4300,7 @@ XNN_INIT_ONCE(qp8_f32_qb4w_gemm); } const struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(qs8_qc8w_gemm); @@ -4054,8 +4308,7 @@ const struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config() { } const struct xnn_gemm_config* xnn_init_qu8_gemm_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { + if (xnn_init_hardware_config() == NULL) { return NULL; } XNN_INIT_ONCE(qu8_gemm); diff --git a/src/configs/hardware-config.c b/src/configs/hardware-config.c index 100186c2adfc..db79bc9ad076 100644 --- a/src/configs/hardware-config.c +++ b/src/configs/hardware-config.c @@ -5,6 +5,10 @@ #include +#if XNN_ENABLE_CPUINFO +#include +#endif // XNN_ENABLE_CPUINFO + #include "xnnpack/common.h" #if _WIN32 @@ -178,7 +182,7 @@ static void init_hardware_config(void) { #endif #if XNN_ENABLE_AVX256VNNI // Using cpuinfo_has_x86_amx_int8 as placeholder for cpuinfo_has_x86_avx10 - hardware_config.use_x86_avx256vnni = (hardware_config.use_x86_avx512skx && cpuinfo_has_x86_avxvnni()) || cpuinfo_has_x86_amx_int8(); + hardware_config.use_x86_avx256vnni = (hardware_config.use_x86_avx512skx && cpuinfo_has_x86_avx512vnni()) || cpuinfo_has_x86_amx_int8(); #else hardware_config.use_x86_avx256vnni = 0; #endif @@ -371,6 +375,59 @@ static void init_hardware_config(void) { if (hardware_config.use_hvx) hardware_config.arch_flags |= xnn_arch_hvx; #endif // XNN_ARCH_HEXAGON +#if XNN_ENABLE_CPUINFO + // Set the size of the L1 and L2 data caches. + if (!cpuinfo_initialize()) { + xnn_log_warning( + "Failed to initialize cpuinfo, unable to determine L1/L2 data cache " + "properties."); + } else { + const struct cpuinfo_processor* proc_info = cpuinfo_get_processor(0); + if (proc_info != NULL) { + // Get the L1 cache information. + const struct cpuinfo_cache* l1_data_cache = proc_info->cache.l1d; + if (l1_data_cache != NULL) { + hardware_config.l1_data_cache_bytes = l1_data_cache->size; + hardware_config.l1_data_cache_line_size = l1_data_cache->line_size; + hardware_config.l1_data_cache_associativity = + l1_data_cache->associativity; + hardware_config.l1_data_cache_num_sets = l1_data_cache->sets; + xnn_log_debug( + "l1_data_cache_bytes=%zu, l1_data_cache_line_size=%zu, " + "l1_data_cache_associativity=%zu, l1_data_cache_num_sets=%zu.\n", + hardware_config.l1_data_cache_bytes, + hardware_config.l1_data_cache_line_size, + hardware_config.l1_data_cache_associativity, + hardware_config.l1_data_cache_num_sets); + } else { + xnn_log_warning("Unable to determine L1 data cache properties."); + } + + // Get the L2 cache information. 
+ const struct cpuinfo_cache* l2_data_cache = proc_info->cache.l2; + if (l2_data_cache != NULL) { + hardware_config.l2_data_cache_bytes = l2_data_cache->size; + hardware_config.l2_data_cache_line_size = l2_data_cache->line_size; + hardware_config.l2_data_cache_associativity = + l2_data_cache->associativity; + hardware_config.l2_data_cache_num_sets = l2_data_cache->sets; + xnn_log_debug( + "l2_data_cache_bytes=%zu, l2_data_cache_line_size=%zu, " + "l2_data_cache_associativity=%zu, l2_data_cache_num_sets=%zu.\n", + hardware_config.l2_data_cache_bytes, + hardware_config.l2_data_cache_line_size, + hardware_config.l2_data_cache_associativity, + hardware_config.l2_data_cache_num_sets); + } else { + xnn_log_warning("Unable to determine L2 data cache properties."); + } + } else { + xnn_log_warning("Unable to determine L1/L2 data cache properties."); + } + } +#else + xnn_log_warning("Unable to determine L1/L2 data cache properties."); +#endif // XNN_ENABLE_CPUINFO } const struct xnn_hardware_config* xnn_init_hardware_config() { diff --git a/src/configs/unary-elementwise-config.c b/src/configs/unary-elementwise-config.c index 496d7c3aba9b..55feaacb66fa 100644 --- a/src/configs/unary-elementwise-config.c +++ b/src/configs/unary-elementwise-config.c @@ -65,7 +65,6 @@ static struct xnn_unary_elementwise_config qs8_to_f32_cvt_config = {0}; static struct xnn_unary_elementwise_config qu8_cvt_config = {0}; static struct xnn_unary_elementwise_config qu8_lrelu_config = {0}; static struct xnn_unary_elementwise_config qu8_to_f32_cvt_config = {0}; -static struct xnn_unary_elementwise_config s32_to_f32_cvt_config = {0}; static struct xnn_unary_elementwise_config s8_clamp_config = {0}; static struct xnn_unary_elementwise_config u8_clamp_config = {0}; static struct xnn_unary_elementwise_config xx_copy_config = {0}; @@ -119,7 +118,6 @@ XNN_INIT_ONCE_GUARD(qs8_to_f32_cvt); XNN_INIT_ONCE_GUARD(qu8_cvt); XNN_INIT_ONCE_GUARD(qu8_lrelu); XNN_INIT_ONCE_GUARD(qu8_to_f32_cvt); -XNN_INIT_ONCE_GUARD(s32_to_f32_cvt); XNN_INIT_ONCE_GUARD(s8_clamp); XNN_INIT_ONCE_GUARD(u8_clamp); XNN_INIT_ONCE_GUARD(xx_copy); @@ -1467,42 +1465,6 @@ static void init_f32_to_qu8_cvt_config(void) { #endif } -static void init_s32_to_f32_cvt_config(void) { - #if XNN_ARCH_ARM - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - assert(hardware_config != NULL); - if (hardware_config->use_arm_neon) { - s32_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__neon_u16; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - } else { - s32_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__scalar_u4; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - } - #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - assert(hardware_config != NULL); - #if XNN_ENABLE_AVX512F - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) { - s32_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__avx512f_u64; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - } else - #endif - if (hardware_config->use_x86_avx2) { - s32_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__avx2_u32; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - } else { - s32_to_f32_cvt_config.ukernel = 
(xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__scalar_u4; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - } - #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD - s32_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__wasmsimd_u16; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - #else - s32_to_f32_cvt_config.ukernel = (xnn_vunary_ukernel_fn) xnn_s32_f32_vcvt_ukernel__scalar_u4; - s32_to_f32_cvt_config.init = (xnn_init_unary_uparams_fn) xnn_init_s32_f32_cvt_scalar_params; - #endif -} - static void init_qs8_cvt_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -2374,15 +2336,6 @@ const struct xnn_unary_elementwise_config* xnn_init_f32_to_qu8_cvt_config() { return &f32_to_qu8_cvt_config; } -const struct xnn_unary_elementwise_config* xnn_init_s32_to_f32_cvt_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { - return NULL; - } - XNN_INIT_ONCE(s32_to_f32_cvt); - return &s32_to_f32_cvt_config; -} - const struct xnn_unary_elementwise_config* xnn_init_qs8_cvt_config() { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); if (hardware_config == NULL) { diff --git a/src/configs/zip-config.c b/src/configs/zip-config.c deleted file mode 100644 index 4a5d3044e6a5..000000000000 --- a/src/configs/zip-config.c +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2023 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include - -#include "xnnpack/common.h" -#include "xnnpack/config.h" -#include "xnnpack/init-once.h" -#include "xnnpack/microfnptr.h" -#include "xnnpack/zip.h" - -static struct xnn_zip_config x8_zip_config = {0}; -static struct xnn_zip_config x32_zip_config = {0}; - -XNN_INIT_ONCE_GUARD(x8_zip); -XNN_INIT_ONCE_GUARD(x32_zip); - -static void init_x8_zip_config(void) { - #if XNN_ARCH_ARM - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - assert(hardware_config != NULL); - if (hardware_config->use_arm_neon) { - x8_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x2_ukernel__neon; - x8_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x3_ukernel__neon; - x8_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x4_ukernel__neon; - x8_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x8_zip_xm_ukernel__neon; - } else if (!XNN_PLATFORM_MOBILE) { - x8_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x2_ukernel__scalar; - x8_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x3_ukernel__scalar; - x8_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x4_ukernel__scalar; - x8_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x8_zip_xm_ukernel__scalar; - } - #elif XNN_ARCH_ARM64 - x8_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x2_ukernel__neon; - x8_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x3_ukernel__neon; - x8_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x4_ukernel__neon; - x8_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x8_zip_xm_ukernel__neon; - #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 - x8_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x2_ukernel__sse2; - x8_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x3_ukernel__sse2; - x8_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x4_ukernel__sse2; - x8_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x8_zip_xm_ukernel__sse2; - #elif 
XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD - x8_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x2_ukernel__scalar; - x8_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x3_ukernel__scalar; - x8_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x4_ukernel__scalar; - x8_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x8_zip_xm_ukernel__scalar; - #else - x8_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x2_ukernel__scalar; - x8_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x3_ukernel__scalar; - x8_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x8_zip_x4_ukernel__scalar; - x8_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x8_zip_xm_ukernel__scalar; - #endif - -} - -static void init_x32_zip_config(void) { - #if XNN_ARCH_ARM - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - assert(hardware_config != NULL); - if (hardware_config->use_arm_neon) { - x32_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x2_ukernel__neon; - x32_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x3_ukernel__neon; - x32_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x4_ukernel__neon; - x32_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x32_zip_xm_ukernel__neon; - } else if (!XNN_PLATFORM_MOBILE) { - x32_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x2_ukernel__scalar; - x32_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x3_ukernel__scalar; - x32_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x4_ukernel__scalar; - x32_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x32_zip_xm_ukernel__scalar; - } - #elif XNN_ARCH_ARM64 - x32_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x2_ukernel__neon; - x32_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x3_ukernel__neon; - x32_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x4_ukernel__neon; - x32_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x32_zip_xm_ukernel__neon; - #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 - x32_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x2_ukernel__sse2; - x32_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x3_ukernel__sse2; - x32_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x4_ukernel__sse2; - x32_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x32_zip_xm_ukernel__sse2; - #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD - x32_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x2_ukernel__wasmsimd; - x32_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x3_ukernel__wasmsimd; - x32_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x4_ukernel__wasmsimd; - x32_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x32_zip_xm_ukernel__wasmsimd; - #else - x32_zip_config.x2 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x2_ukernel__scalar; - x32_zip_config.x3 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x3_ukernel__scalar; - x32_zip_config.x4 = (xnn_zipc_ukernel_fn) xnn_x32_zip_x4_ukernel__scalar; - x32_zip_config.xm = (xnn_zipv_ukernel_fn) xnn_x32_zip_xm_ukernel__scalar; - #endif - -} - -const struct xnn_zip_config* xnn_init_x8_zip_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { - return NULL; - } - XNN_INIT_ONCE(x8_zip); - return &x8_zip_config; -} - -const struct xnn_zip_config* xnn_init_x32_zip_config() { - const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); - if (hardware_config == NULL) { - return NULL; - } - XNN_INIT_ONCE(x32_zip); - return &x32_zip_config; -} diff --git a/src/datatype.c b/src/datatype.c index 16dfb7465fc4..1a00c7e58637 100644 --- a/src/datatype.c +++ b/src/datatype.c @@ -22,6 +22,7 @@ bool xnn_datatype_is_real(enum xnn_datatype t) { 
case xnn_datatype_qcint32: case xnn_datatype_qcint4: case xnn_datatype_qdint8: + case xnn_datatype_qduint8: case xnn_datatype_qpint8: case xnn_datatype_qbint4: case xnn_datatype_pfp32: @@ -44,6 +45,7 @@ bool xnn_datatype_is_integral(enum xnn_datatype t) { case xnn_datatype_qcint32: case xnn_datatype_qcint4: case xnn_datatype_qdint8: + case xnn_datatype_qduint8: case xnn_datatype_qpint8: case xnn_datatype_qbint4: case xnn_datatype_pfp32: @@ -64,6 +66,7 @@ bool xnn_datatype_is_quantized(enum xnn_datatype t) { case xnn_datatype_qcint32: case xnn_datatype_qcint4: case xnn_datatype_qdint8: + case xnn_datatype_qduint8: case xnn_datatype_qpint8: case xnn_datatype_qbint4: return true; @@ -91,6 +94,7 @@ size_t xnn_datatype_log2_size_bits(enum xnn_datatype t) { case xnn_datatype_quint8: case xnn_datatype_qcint8: case xnn_datatype_qdint8: + case xnn_datatype_qduint8: case xnn_datatype_qpint8: return 3; case xnn_datatype_fp16: @@ -137,6 +141,7 @@ bool xnn_datatype_is_byte_addressable(enum xnn_datatype t) { case xnn_datatype_qcint8: case xnn_datatype_qcint32: case xnn_datatype_qdint8: + case xnn_datatype_qduint8: case xnn_datatype_int32: case xnn_datatype_fp32: return true; diff --git a/src/enums/datatype-strings.c b/src/enums/datatype-strings.c index e0f5368473b4..1d2657451a2e 100644 --- a/src/enums/datatype-strings.c +++ b/src/enums/datatype-strings.c @@ -38,6 +38,8 @@ const char* xnn_datatype_to_string(enum xnn_datatype type) { return "QCINT32"; case xnn_datatype_qdint8: return "QDINT8"; + case xnn_datatype_qduint8: + return "QDUINT8"; case xnn_datatype_qpint8: return "QPINT8"; case xnn_datatype_int32: diff --git a/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u32.c b/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u32.c index d86d2489cb41..dbc2c81b2a9d 100644 --- a/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u32.c +++ b/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u32.c @@ -56,7 +56,7 @@ void xnn_f16_vprelu_ukernel__avx512fp16_u32( const __m512h va = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, a)); const __mmask32 vsign = _mm512_cmp_ph_mask(va, vzero, _CMP_LT_OQ); - __m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_loadu_ph(b)); + __m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b))); _mm512_mask_storeu_epi16(o, vmask, _mm512_castph_si512(vacc)); } diff --git a/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u64.c b/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u64.c index 31720e3bb675..76ad24ade7ab 100644 --- a/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u64.c +++ b/src/f16-vbinary/gen/f16-vprelu-avx512fp16-u64.c @@ -71,7 +71,7 @@ void xnn_f16_vprelu_ukernel__avx512fp16_u64( const __m512h va = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, a)); const __mmask32 vsign = _mm512_cmp_ph_mask(va, vzero, _CMP_LT_OQ); - __m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_loadu_ph(b)); + __m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b))); _mm512_mask_storeu_epi16(o, vmask, _mm512_castph_si512(vacc)); } diff --git a/src/f16-vbinary/vop-avx512fp16.c.in b/src/f16-vbinary/vop-avx512fp16.c.in index df8c8e55ea0c..1d06a6bd2d9b 100644 --- a/src/f16-vbinary/vop-avx512fp16.c.in +++ b/src/f16-vbinary/vop-avx512fp16.c.in @@ -113,7 +113,7 @@ void xnn_f16_v${OP.lower()}_ukernel__avx512fp16_u${BATCH_TILE}( $if OP == "PRELU": const __mmask32 vsign = _mm512_cmp_ph_mask(va, vzero, _CMP_LT_OQ); - __m512h vacc = _mm512_mask_mul_ph(va, vsign, va, _mm512_loadu_ph(b)); + __m512h vacc = _mm512_mask_mul_ph(va, vsign, va, 
_mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b))); $else: __m512h vacc = ${_MM512_MASKZ_OP_ph}(vmask, va, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(vmask, b))); diff --git a/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..f09b6b989a6a --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,298 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with the biases. 
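+ # Descriptive note on the step below: zmm7 is loaded with 16 bias values from the packed
+ # weights pointer (r9) and fanned out to the other nine row accumulators (zmm8, zmm9,
+ # zmm14-zmm20), one accumulator per row of this 10x16 tile.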
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + vmovups [r8], zmm20 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..13f7f2fc41c1 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,361 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with the biases. 
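+ # Descriptive note on the step below: two 16-wide bias vectors (32 floats) are read from r9.
+ # zmm7 seeds the accumulators for columns 0-15 (zmm8, zmm9, zmm14-zmm20) and zmm21 seeds
+ # those for columns 16-31 (zmm22-zmm30), one register pair per row of the 10x32 tile.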
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm21, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + vmovaps zmm22, zmm21 + vmovaps zmm23, zmm21 + vmovaps zmm24, zmm21 + vmovaps zmm25, zmm21 + vmovaps zmm26, zmm21 + vmovaps zmm27, zmm21 + vmovaps zmm28, zmm21 + vmovaps zmm29, zmm21 + vmovaps zmm30, zmm21 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm27, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm28, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm29, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + vfmadd231ps zmm30, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. 
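+ # Descriptive note on the step below: rcx holds the number of output columns still to be
+ # written. A full 32-column tile is stored with plain vmovups; anything narrower falls
+ # through to the masked tail store built from cl.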
+ cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm21 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm22 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm23 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm24 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm25 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm26 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm27 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm28 + vmovups [rbp], zmm19 + vmovups [rbp + 64], zmm29 + vmovups [r8], zmm20 + vmovups [r8 + 64], zmm30 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm30 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..09b78ddbf9ed --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,321 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with the biases. 
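+ # Descriptive note on the step below: a single 16-wide bias vector is loaded into zmm7 and
+ # copied to the other ten row accumulators (zmm8, zmm9, zmm14-zmm21), one per row of the
+ # 11x16 tile.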
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + vmovaps zmm21, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rdi + r11] + vfmadd231ps zmm21, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + vmovups [r8], zmm20 + vmovups [rdi], zmm21 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + add rdi, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [rdi]{k1}, zmm21 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..576ef75275fd --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,390 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. 
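+ # Descriptive note on the step below: with 11 rows the kernel tracks 22 a/c pointers, more
+ # than the free general-purpose registers, so the clamped pointers computed in the prologue
+ # live on the stack and are reloaded at the start of every pass over the output columns.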
+ mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm22, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm20, zmm7 + vmovaps zmm21, zmm7 + vmovaps zmm23, zmm22 + vmovaps zmm24, zmm22 + vmovaps zmm25, zmm22 + vmovaps zmm26, zmm22 + vmovaps zmm27, zmm22 + vmovaps zmm28, zmm22 + vmovaps zmm29, zmm22 + vmovaps zmm30, zmm22 + vmovaps zmm12, zmm22 + vmovaps zmm13, zmm22 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm27, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm28, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm29, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm30, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rdi + r11] + vfmadd231ps zmm21, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + + # Pop output pointers from the stack. 
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm22 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm23 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm24 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm25 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm26 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm27 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm28 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm29 + vmovups [rbp], zmm19 + vmovups [rbp + 64], zmm30 + vmovups [r8], zmm20 + vmovups [r8 + 64], zmm12 + vmovups [rdi], zmm21 + vmovups [rdi + 64], zmm13 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + add rdi, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm30 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [rdi]{k1}, zmm21 + vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm13 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..a20251cc6419 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,96 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. 
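+ # Descriptive note on the step below: the packed-weights pointer x5 supplies 16 bias floats,
+ # loaded into q11-q14; the same four registers accumulate the single output row.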
+ ldp q11, q12, [x5, 0] + ldp q13, q14, [x5, 32] + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v8.4s, v2.s[0] + fmla v13.4s, v9.4s, v2.s[0] + fmla v14.4s, v10.4s, v2.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q11, q12, [x6], 32 + stp q13, q14, [x6], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q12, [x6], 32 + mov v11.16b, v13.16b + mov v12.16b, v14.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + mov v11.16b, v12.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + dup d11, v11.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..1dfd045b8228 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,78 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vmaxps zmm7, zmm0, zmm7 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + add r10, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..85b0958e4d38 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,87 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm8, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm8 + add r10, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..9b3b2acd90a4 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,106 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
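+ # Descriptive note on the step below: the single output row spans 64 columns, so four
+ # 16-wide bias vectors are loaded from r9 into zmm7, zmm8, zmm9 and zmm14; no copies are
+ # needed since there is only one row of accumulators.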
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm14, [r9 + 192] + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm8, zmm2, zmm11 + vfmadd231ps zmm9, zmm2, zmm12 + vfmadd231ps zmm14, zmm2, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm8 + vmovups [r10 + 128], zmm9 + vmovups [r10 + 192], zmm14 + add r10, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm9 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm14 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..3175bb6408d4 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,80 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q12, [x5, 0] + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v8.4s, v2.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q11, q12, [x6], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + mov v11.16b, v12.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + dup d11, v11.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..6c361d469337 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,130 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q13, [x5, 0] + ldp q15, q17, [x5, 32] + mov v12.16b, v11.16b + mov v14.16b, v13.16b + mov v16.16b, v15.16b + mov v18.16b, v17.16b + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v8.4s, v2.s[0] + fmla v14.4s, v8.4s, v3.s[0] + fmla v15.4s, v9.4s, v2.s[0] + fmla v16.4s, v9.4s, v3.s[0] + fmla v17.4s, v10.4s, v2.s[0] + fmla v18.4s, v10.4s, v3.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q11, q13, [x6], 32 + stp q15, q17, [x6], 32 + stp q12, q14, [x13], 32 + stp q16, q18, [x13], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q13, [x6], 32 + stp q12, q14, [x13], 32 + mov v11.16b, v15.16b + mov v13.16b, v17.16b + mov v12.16b, v16.16b + mov v14.16b, v18.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + mov v11.16b, v13.16b + mov v12.16b, v14.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..f429e9631c6f --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,95 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + add r10, 64 + add r13, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..212db9528b5f --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,110 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm9, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm14, zmm9 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm9, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm14, zmm3, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm9 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm14 + add r10, 128 + add r13, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..d79b807966c0 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,141 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm9, [r9 + 64] + vmovaps zmm15, [r9 + 128] + vmovaps zmm17, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm14, zmm9 + vmovaps zmm16, zmm15 + vmovaps zmm18, zmm17 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm9, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm12 + vfmadd231ps zmm17, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm14, zmm3, zmm11 + vfmadd231ps zmm16, zmm3, zmm12 + vfmadd231ps zmm18, zmm3, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm9 + vmovups [r10 + 128], zmm15 + vmovups [r10 + 192], zmm17 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm14 + vmovups [r13 + 128], zmm16 + vmovups [r13 + 192], zmm18 + add r10, 256 + add r13, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm15 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm16 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm18 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..1749ed7217d4 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,102 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q13, [x5, 0] + mov v12.16b, v11.16b + mov v14.16b, v13.16b + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v8.4s, v2.s[0] + fmla v14.4s, v8.4s, v3.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + + # Check whether full or partial store. 
+ cmp x1, 8 + b.lo tail_4 + stp q11, q13, [x6], 32 + stp q12, q14, [x13], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + mov v11.16b, v13.16b + mov v12.16b, v14.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..a963c5fbab5b --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,161 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q14, [x5, 0] + ldp q17, q20, [x5, 32] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v15.16b, v14.16b + mov v16.16b, v14.16b + mov v18.16b, v17.16b + mov v19.16b, v17.16b + mov v21.16b, v20.16b + mov v22.16b, v20.16b + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v8.4s, v2.s[0] + fmla v15.4s, v8.4s, v3.s[0] + fmla v16.4s, v8.4s, v4.s[0] + fmla v17.4s, v9.4s, v2.s[0] + fmla v18.4s, v9.4s, v3.s[0] + fmla v19.4s, v9.4s, v4.s[0] + fmla v20.4s, v10.4s, v2.s[0] + fmla v21.4s, v10.4s, v3.s[0] + fmla v22.4s, v10.4s, v4.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + + # Check whether full or partial store. 
+ cmp x1, 16 + b.lo tail_8 + stp q11, q14, [x6], 32 + stp q17, q20, [x6], 32 + stp q12, q15, [x13], 32 + stp q18, q21, [x13], 32 + stp q13, q16, [x14], 32 + stp q19, q22, [x14], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q14, [x6], 32 + stp q12, q15, [x13], 32 + stp q13, q16, [x14], 32 + mov v11.16b, v17.16b + mov v14.16b, v20.16b + mov v12.16b, v18.16b + mov v15.16b, v21.16b + mov v13.16b, v19.16b + mov v16.16b, v22.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + mov v11.16b, v14.16b + mov v12.16b, v15.16b + mov v13.16b, v16.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..dc65d2c61afe --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,112 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + + # Check whether full or partial store. 
+ cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + vmovups [rbx], zmm9 + add r10, 64 + add r13, 64 + add rbx, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..92db444e85ac --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,133 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm14, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm15, zmm14 + vmovaps zmm16, zmm14 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm15, zmm3, zmm11 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm16, zmm4, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm14 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm15 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm16 + add r10, 128 + add r13, 128 + add rbx, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..ca88fdb48066 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,176 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm14, [r9 + 64] + vmovaps zmm17, [r9 + 128] + vmovaps zmm20, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm15, zmm14 + vmovaps zmm16, zmm14 + vmovaps zmm18, zmm17 + vmovaps zmm19, zmm17 + vmovaps zmm21, zmm20 + vmovaps zmm22, zmm20 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm12 + vfmadd231ps zmm20, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm15, zmm3, zmm11 + vfmadd231ps zmm18, zmm3, zmm12 + vfmadd231ps zmm21, zmm3, zmm13 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm16, zmm4, zmm11 + vfmadd231ps zmm19, zmm4, zmm12 + vfmadd231ps zmm22, zmm4, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + + # Check whether full or partial store. 
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm14 + vmovups [r10 + 128], zmm17 + vmovups [r10 + 192], zmm20 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm15 + vmovups [r13 + 128], zmm18 + vmovups [r13 + 192], zmm21 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm16 + vmovups [rbx + 128], zmm19 + vmovups [rbx + 192], zmm22 + add r10, 256 + add r13, 256 + add rbx, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm17 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm20 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm18 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm21 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm19 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm22 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..195fd2bcbfed --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,121 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q14, [x5, 0] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v15.16b, v14.16b + mov v16.16b, v14.16b + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v8.4s, v2.s[0] + fmla v15.4s, v8.4s, v3.s[0] + fmla v16.4s, v8.4s, v4.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + + # Check whether full or partial store. 
+ cmp x1, 8 + b.lo tail_4 + stp q11, q14, [x6], 32 + stp q12, q15, [x13], 32 + stp q13, q16, [x14], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + mov v11.16b, v14.16b + mov v12.16b, v15.16b + mov v13.16b, v16.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..f91ba2bdde1b --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,194 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q15, [x5, 0] + ldp q19, q23, [x5, 32] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v16.16b, v15.16b + mov v17.16b, v15.16b + mov v18.16b, v15.16b + mov v20.16b, v19.16b + mov v21.16b, v19.16b + mov v22.16b, v19.16b + mov v24.16b, v23.16b + mov v25.16b, v23.16b + mov v26.16b, v23.16b + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldr d5, [x11, x20] + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v8.4s, v2.s[0] + fmla v16.4s, v8.4s, v3.s[0] + fmla v17.4s, v8.4s, v4.s[0] + fmla v18.4s, v8.4s, v5.s[0] + fmla v19.4s, v9.4s, v2.s[0] + fmla v20.4s, v9.4s, v3.s[0] + fmla v21.4s, v9.4s, v4.s[0] + fmla v22.4s, v9.4s, v5.s[0] + fmla v23.4s, v10.4s, v2.s[0] + fmla v24.4s, v10.4s, v3.s[0] + fmla v25.4s, v10.4s, v4.s[0] + fmla v26.4s, v10.4s, v5.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. 
+ fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmin v24.4s, v1.4s, v24.4s + fmin v25.4s, v1.4s, v25.4s + fmin v26.4s, v1.4s, v26.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + fmax v24.4s, v0.4s, v24.4s + fmax v25.4s, v0.4s, v25.4s + fmax v26.4s, v0.4s, v26.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q11, q15, [x6], 32 + stp q19, q23, [x6], 32 + stp q12, q16, [x13], 32 + stp q20, q24, [x13], 32 + stp q13, q17, [x14], 32 + stp q21, q25, [x14], 32 + stp q14, q18, [x15], 32 + stp q22, q26, [x15], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q15, [x6], 32 + stp q12, q16, [x13], 32 + stp q13, q17, [x14], 32 + stp q14, q18, [x15], 32 + mov v11.16b, v19.16b + mov v15.16b, v23.16b + mov v12.16b, v20.16b + mov v16.16b, v24.16b + mov v13.16b, v21.16b + mov v17.16b, v25.16b + mov v14.16b, v22.16b + mov v18.16b, v26.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + mov v11.16b, v15.16b + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..71d442f0da79 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,129 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + vmovups [rbx], zmm9 + vmovups [rbp], zmm14 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..7f117407c4ad --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,156 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm15, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm16, zmm15 + vmovaps zmm17, zmm15 + vmovaps zmm18, zmm15 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm16, zmm3, zmm11 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm17, zmm4, zmm11 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm18, zmm5, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm15 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm16 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm17 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm18 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..e862b8a36de9 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,211 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm15, [r9 + 64] + vmovaps zmm19, [r9 + 128] + vmovaps zmm23, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm16, zmm15 + vmovaps zmm17, zmm15 + vmovaps zmm18, zmm15 + vmovaps zmm20, zmm19 + vmovaps zmm21, zmm19 + vmovaps zmm22, zmm19 + vmovaps zmm24, zmm23 + vmovaps zmm25, zmm23 + vmovaps zmm26, zmm23 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm11 + vfmadd231ps zmm19, zmm2, zmm12 + vfmadd231ps zmm23, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm16, zmm3, zmm11 + vfmadd231ps zmm20, zmm3, zmm12 + vfmadd231ps zmm24, zmm3, zmm13 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm17, zmm4, zmm11 + vfmadd231ps zmm21, zmm4, zmm12 + vfmadd231ps zmm25, zmm4, zmm13 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm18, zmm5, zmm11 + vfmadd231ps zmm22, zmm5, zmm12 + vfmadd231ps zmm26, zmm5, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + + # Check whether full or partial store. 
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm15 + vmovups [r10 + 128], zmm19 + vmovups [r10 + 192], zmm23 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm16 + vmovups [r13 + 128], zmm20 + vmovups [r13 + 192], zmm24 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm17 + vmovups [rbx + 128], zmm21 + vmovups [rbx + 192], zmm25 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm18 + vmovups [rbp + 128], zmm22 + vmovups [rbp + 192], zmm26 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm19 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm23 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm20 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm24 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm21 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm25 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm26 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..31d3ebf78e99 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,142 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. + ldp q11, q15, [x5, 0] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v16.16b, v15.16b + mov v17.16b, v15.16b + mov v18.16b, v15.16b + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldr d5, [x11, x20] + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v8.4s, v2.s[0] + fmla v16.4s, v8.4s, v3.s[0] + fmla v17.4s, v8.4s, v4.s[0] + fmla v18.4s, v8.4s, v5.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. 
+ fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q11, q15, [x6], 32 + stp q12, q16, [x13], 32 + stp q13, q17, [x14], 32 + stp q14, q18, [x15], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + mov v11.16b, v15.16b + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..8f5d12c8e6c9 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,225 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x12, x11, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + add x19, x15, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + csel x12, x11, x12, LS + csel x19, x15, x19, LS + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. 
+ ldp q11, q16, [x5, 0] + ldp q21, q26, [x5, 32] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v15.16b, v11.16b + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v22.16b, v21.16b + mov v23.16b, v21.16b + mov v24.16b, v21.16b + mov v25.16b, v21.16b + mov v27.16b, v26.16b + mov v28.16b, v26.16b + mov v29.16b, v26.16b + mov v30.16b, v26.16b + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldr d5, [x11, x20] + ldr d6, [x12, x20] + ldp q7, q8, [x5], 32 + ldp q9, q10, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v7.4s, v6.s[0] + fmla v16.4s, v8.4s, v2.s[0] + fmla v17.4s, v8.4s, v3.s[0] + fmla v18.4s, v8.4s, v4.s[0] + fmla v19.4s, v8.4s, v5.s[0] + fmla v20.4s, v8.4s, v6.s[0] + fmla v21.4s, v9.4s, v2.s[0] + fmla v22.4s, v9.4s, v3.s[0] + fmla v23.4s, v9.4s, v4.s[0] + fmla v24.4s, v9.4s, v5.s[0] + fmla v25.4s, v9.4s, v6.s[0] + fmla v26.4s, v10.4s, v2.s[0] + fmla v27.4s, v10.4s, v3.s[0] + fmla v28.4s, v10.4s, v4.s[0] + fmla v29.4s, v10.4s, v5.s[0] + fmla v30.4s, v10.4s, v6.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmin v24.4s, v1.4s, v24.4s + fmin v25.4s, v1.4s, v25.4s + fmin v26.4s, v1.4s, v26.4s + fmin v27.4s, v1.4s, v27.4s + fmin v28.4s, v1.4s, v28.4s + fmin v29.4s, v1.4s, v29.4s + fmin v30.4s, v1.4s, v30.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + fmax v24.4s, v0.4s, v24.4s + fmax v25.4s, v0.4s, v25.4s + fmax v26.4s, v0.4s, v26.4s + fmax v27.4s, v0.4s, v27.4s + fmax v28.4s, v0.4s, v28.4s + fmax v29.4s, v0.4s, v29.4s + fmax v30.4s, v0.4s, v30.4s + + # Check whether full or partial store. 
+ cmp x1, 16 + b.lo tail_8 + stp q11, q16, [x6], 32 + stp q21, q26, [x6], 32 + stp q12, q17, [x13], 32 + stp q22, q27, [x13], 32 + stp q13, q18, [x14], 32 + stp q23, q28, [x14], 32 + stp q14, q19, [x15], 32 + stp q24, q29, [x15], 32 + stp q15, q20, [x19], 32 + stp q25, q30, [x19], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q11, q16, [x6], 32 + stp q12, q17, [x13], 32 + stp q13, q18, [x14], 32 + stp q14, q19, [x15], 32 + stp q15, q20, [x19], 32 + mov v11.16b, v21.16b + mov v16.16b, v26.16b + mov v12.16b, v22.16b + mov v17.16b, v27.16b + mov v13.16b, v23.16b + mov v18.16b, v28.16b + mov v14.16b, v24.16b + mov v19.16b, v29.16b + mov v15.16b, v25.16b + mov v20.16b, v30.16b + + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + str q15, [x19], 16 + mov v11.16b, v16.16b + mov v12.16b, v17.16b + mov v13.16b, v18.16b + mov v14.16b, v19.16b + mov v15.16b, v20.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + str d15, [x19], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + str s15, [x19] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..b6743cce4672 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,146 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vbroadcastss zmm6, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm7 + vmovups [r13], zmm8 + vmovups [rbx], zmm9 + vmovups [rbp], zmm14 + vmovups [r8], zmm15 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..c0273cc2f80d --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,179 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm16, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm17, zmm16 + vmovaps zmm18, zmm16 + vmovaps zmm19, zmm16 + vmovaps zmm20, zmm16 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm11 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm17, zmm3, zmm11 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm18, zmm4, zmm11 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm19, zmm5, zmm11 + vbroadcastss zmm6, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm10 + vfmadd231ps zmm20, zmm6, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm16 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm17 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm18 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm19 + vmovups [r8], zmm15 + vmovups [r8 + 64], zmm20 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm20 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..8a2a511a2474 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,246 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
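+      # The six pushes above move the stacked arguments 48 bytes further out,
+      # so c, cm_stride and params are read from rsp + 56, rsp + 64 and rsp + 80.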
+ mov r11, [rsp + 64] + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm16, [r9 + 64] + vmovaps zmm21, [r9 + 128] + vmovaps zmm26, [r9 + 192] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm17, zmm16 + vmovaps zmm18, zmm16 + vmovaps zmm19, zmm16 + vmovaps zmm20, zmm16 + vmovaps zmm22, zmm21 + vmovaps zmm23, zmm21 + vmovaps zmm24, zmm21 + vmovaps zmm25, zmm21 + vmovaps zmm27, zmm26 + vmovaps zmm28, zmm26 + vmovaps zmm29, zmm26 + vmovaps zmm30, zmm26 + add r9, 256 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm12, [r9 + 128] + vmovaps zmm13, [r9 + 192] + add r9, 256 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm11 + vfmadd231ps zmm21, zmm2, zmm12 + vfmadd231ps zmm26, zmm2, zmm13 + vbroadcastss zmm3, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm17, zmm3, zmm11 + vfmadd231ps zmm22, zmm3, zmm12 + vfmadd231ps zmm27, zmm3, zmm13 + vbroadcastss zmm4, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm18, zmm4, zmm11 + vfmadd231ps zmm23, zmm4, zmm12 + vfmadd231ps zmm28, zmm4, zmm13 + vbroadcastss zmm5, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm19, zmm5, zmm11 + vfmadd231ps zmm24, zmm5, zmm12 + vfmadd231ps zmm29, zmm5, zmm13 + vbroadcastss zmm6, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm10 + vfmadd231ps zmm20, zmm6, zmm11 + vfmadd231ps zmm25, zmm6, zmm12 + vfmadd231ps zmm30, zmm6, zmm13 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + + # Check whether full or partial store. 
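+      # A full 64-column tile is stored with plain vmovups; otherwise the tail
+      # below builds the predicate (1 << nc) - 1 (as ~(-1 << nc)) and splits it
+      # into four 16-bit masks k1..k4 for masked stores of the partial tile.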
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm7 + vmovups [r10 + 64], zmm16 + vmovups [r10 + 128], zmm21 + vmovups [r10 + 192], zmm26 + vmovups [r13], zmm8 + vmovups [r13 + 64], zmm17 + vmovups [r13 + 128], zmm22 + vmovups [r13 + 192], zmm27 + vmovups [rbx], zmm9 + vmovups [rbx + 64], zmm18 + vmovups [rbx + 128], zmm23 + vmovups [rbx + 192], zmm28 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm19 + vmovups [rbp + 128], zmm24 + vmovups [rbp + 192], zmm29 + vmovups [r8], zmm15 + vmovups [r8 + 64], zmm20 + vmovups [r8 + 128], zmm25 + vmovups [r8 + 192], zmm30 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + add r8, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm21 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm26 + vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm27 + vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm23 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm28 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm24 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm29 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r8 + 128]{k3}, zmm25 + vmovups ZMMWORD PTR [r8 + 192]{k4}, zmm30 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S new file mode 100644 index 000000000000..621a6a6c740f --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S @@ -0,0 +1,161 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x12, x11, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + add x19, x15, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + csel x12, x11, x12, LS + csel x19, x15, x19, LS + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with the biases. 
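+      # Rows at and beyond mr alias the last valid row (csel above), so the
+      # extra rows simply recompute and rewrite that row's values instead of
+      # reading or writing past the caller's buffers.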
+ ldp q11, q16, [x5, 0] + mov v12.16b, v11.16b + mov v13.16b, v11.16b + mov v14.16b, v11.16b + mov v15.16b, v11.16b + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldr d5, [x11, x20] + ldr d6, [x12, x20] + ldp q7, q8, [x5], 32 + fmla v11.4s, v7.4s, v2.s[0] + fmla v12.4s, v7.4s, v3.s[0] + fmla v13.4s, v7.4s, v4.s[0] + fmla v14.4s, v7.4s, v5.s[0] + fmla v15.4s, v7.4s, v6.s[0] + fmla v16.4s, v8.4s, v2.s[0] + fmla v17.4s, v8.4s, v3.s[0] + fmla v18.4s, v8.4s, v4.s[0] + fmla v19.4s, v8.4s, v5.s[0] + fmla v20.4s, v8.4s, v6.s[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + # Min/max clamping.. + fmin v11.4s, v1.4s, v11.4s + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmax v11.4s, v0.4s, v11.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q11, q16, [x6], 32 + stp q12, q17, [x13], 32 + stp q13, q18, [x14], 32 + stp q14, q19, [x15], 32 + stp q15, q20, [x19], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q11, [x6], 16 + str q12, [x13], 16 + str q13, [x14], 16 + str q14, [x15], 16 + str q15, [x19], 16 + mov v11.16b, v16.16b + mov v12.16b, v17.16b + mov v13.16b, v18.16b + mov v14.16b, v19.16b + mov v15.16b, v20.16b + + +tail_2: + tbz x1, 1, tail_1 + str d11, [x6], 8 + str d12, [x13], 8 + str d13, [x14], 8 + str d14, [x15], 8 + str d15, [x19], 8 + dup d11, v11.d[1] + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s11, [x6] + str s12, [x13] + str s13, [x14] + str s14, [x15] + str s15, [x19] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..48c1de312f3c --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,206 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. 
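+      # With six rows there are not enough free GP registers to keep every
+      # per-row a and c pointer live, so the clamped pointers are parked in
+      # stack slots and reloaded at the top of each outer_loop iteration.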
+ mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + +return: + + # Restore the callee saved registers. 
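+      # Only the six GP registers pushed in the prologue need restoring; the
+      # SysV AMD64 ABI treats all zmm registers as caller-saved, unlike the
+      # AArch64 kernels above, which preserve q8-q15.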
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..0a12f3e02af3 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,245 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with the biases. 
+ vmovaps zmm7, [r9 + 0] + vmovaps zmm17, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm18, zmm17 + vmovaps zmm19, zmm17 + vmovaps zmm20, zmm17 + vmovaps zmm21, zmm17 + vmovaps zmm22, zmm17 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm17 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm18 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm19 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm20 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm21 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm22 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm22 + +return: + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..7d83d0394bf2 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,229 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
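+      # acc = max(params->min, min(params->max, acc)): the output max (zmm1)
+      # is applied first with vminps, then the output min (zmm0) with vmaxps.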
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..57f6b2cd431a --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,274 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm18, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm19, zmm18 + vmovaps zmm20, zmm18 + vmovaps zmm21, zmm18 + vmovaps zmm22, zmm18 + vmovaps zmm23, zmm18 + vmovaps zmm24, zmm18 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + + # Pop output pointers from the stack. 
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm18 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm19 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm20 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm21 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm22 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm23 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm24 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm24 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..70355ca7cb7e --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,252 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. + vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. 
+ cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..4a6a40971460 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,303 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm19, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm20, zmm19 + vmovaps zmm21, zmm19 + vmovaps zmm22, zmm19 + vmovaps zmm23, zmm19 + vmovaps zmm24, zmm19 + vmovaps zmm25, zmm19 + vmovaps zmm26, zmm19 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm19 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm20 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm21 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm22 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm23 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm24 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm25 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm26 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm26 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..b37845822da8 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,275 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. 
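+      # After the swap rsi holds the a pointer and rcx holds nc, so the
+      # remaining-column count is available in cl for the tail-mask shift.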
+ mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + add r9, 64 + +inner_loop: + vmovaps zmm10, [r9 + 0] + add r9, 64 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm7 + vmovups [rax], zmm8 + vmovups [r15], zmm9 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + +return: + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 000000000000..beda0925e7ee --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,332 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with the biases. + vmovaps zmm7, [r9 + 0] + vmovaps zmm20, [r9 + 64] + vmovaps zmm8, zmm7 + vmovaps zmm9, zmm7 + vmovaps zmm14, zmm7 + vmovaps zmm15, zmm7 + vmovaps zmm16, zmm7 + vmovaps zmm17, zmm7 + vmovaps zmm18, zmm7 + vmovaps zmm19, zmm7 + vmovaps zmm21, zmm20 + vmovaps zmm22, zmm20 + vmovaps zmm23, zmm20 + vmovaps zmm24, zmm20 + vmovaps zmm25, zmm20 + vmovaps zmm26, zmm20 + vmovaps zmm27, zmm20 + vmovaps zmm28, zmm20 + add r9, 128 + +inner_loop: + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vbroadcastss zmm2, DWORD PTR [rsi + r11] + vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rax + r11] + vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r15 + r11] + vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm22, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm23, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm24, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm25, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm26, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm27, zmm2, zmm11 + vbroadcastss zmm2, DWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm28, zmm2, zmm11 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + # Min/max clamping.. 
+ vminps zmm7, zmm1, zmm7 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vmaxps zmm7, zmm0, zmm7 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm7 + vmovups [rsi + 64], zmm20 + vmovups [rax], zmm8 + vmovups [rax + 64], zmm21 + vmovups [r15], zmm9 + vmovups [r15 + 64], zmm22 + vmovups [r14], zmm14 + vmovups [r14 + 64], zmm23 + vmovups [r12], zmm15 + vmovups [r12 + 64], zmm24 + vmovups [r10], zmm16 + vmovups [r10 + 64], zmm25 + vmovups [r13], zmm17 + vmovups [r13 + 64], zmm26 + vmovups [rbx], zmm18 + vmovups [rbx + 64], zmm27 + vmovups [rbp], zmm19 + vmovups [rbp + 64], zmm28 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm7 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm28 + +return: + + # Restore the callee saved registers. 
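  # (Note on the tail path above: r11d = ~(-1 << cl) sets one bit per
  #  remaining output column; kmovw k1 takes bits 0..15 to mask the first ZMM
  #  store of each row and, after shr r11d, 16, k2 takes bits 16..31 for the
  #  second, so the masked vmovups instructions write only the valid columns
  #  of the final, partial tile.)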
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast \ No newline at end of file diff --git a/src/f32-raddextexp/f32-raddextexp.h b/src/f32-raddextexp/f32-raddextexp.h new file mode 100644 index 000000000000..539afddcbbee --- /dev/null +++ b/src/f32-raddextexp/f32-raddextexp.h @@ -0,0 +1,55 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +#ifndef XNN_UKERNEL_WITH_PARAMS +#define XNN_UKERNEL_WITH_PARAMS(arch_flags, ukernel, element_tile, datatype, params_type, init_params) \ + XNN_UKERNEL(arch_flags, ukernel, element_tile, datatype) +#define XNN_DEFINED_UKERNEL_WITH_PARAMS +#endif + +#ifndef XNN_UKERNEL +#define XNN_UKERNEL(arch_flags, ukernel, element_tile, datatype) \ + XNN_UKERNEL_WITH_PARAMS(arch_flags, ukernel, element_tile, datatype, void, /*init_params=*/nullptr) +#define XNN_DEFINED_UKERNEL +#endif + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u64, 64, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2, 64, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4, 64, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u72, 72, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3, 72, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u80, 80, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2, 80, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5, 80, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u96, 96, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2, 96, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3, 96, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6, 96, float, struct xnn_f32_default_params, NULL) +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + +#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128, 128, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2, 128, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4, 128, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144, 144, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3, 144, float, struct 
xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160, 160, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2, 160, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5, 160, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192, 192, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2, 192, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3, 192, float, struct xnn_f32_default_params, NULL) +XNN_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6, 192, float, struct xnn_f32_default_params, NULL) +#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) + +#ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS +#undef XNN_DEFINED_UKERNEL_WITH_PARAMS +#undef XNN_UKERNEL_WITH_PARAMS +#endif + +#ifdef XNN_DEFINED_UKERNEL +#undef XNN_DEFINED_UKERNEL +#undef XNN_UKERNEL +#endif diff --git a/src/indirection.c b/src/indirection.c index e8be4d5e7399..dd678e37a4c9 100644 --- a/src/indirection.c +++ b/src/indirection.c @@ -242,7 +242,7 @@ void xnn_indirection_init_dwconv2d_compressed( for (size_t output_x = 0; output_x < output_width; output_x++) { for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { const size_t input_x = output_x * stride_width + kernel_x * dilation_width - input_padding_left; - const size_t index = indirection_y * step_height + output_x * step_width * kernel_height + kernel_x * kernel_height + kernel_y; + const size_t index = indirection_y * step_height + (output_x * step_width + kernel_x) * kernel_height + kernel_y; if (input_x < input_width) { indirection_buffer[index] = (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride); @@ -254,7 +254,7 @@ void xnn_indirection_init_dwconv2d_compressed( } else { for (size_t output_x = 0; output_x < output_width; output_x++) { for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { - const size_t index = output_y * step_height + output_x * step_width * kernel_height + kernel_x * kernel_height + kernel_y; + const size_t index = output_y * step_height + (output_x * step_width + kernel_x) * kernel_height + kernel_y; indirection_buffer[index] = zero_buffer; } } @@ -270,7 +270,7 @@ void xnn_indirection_init_dwconv2d_compressed( for (size_t output_x = 0; output_x < output_width; output_x++) { for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { const size_t input_x = output_x * stride_width + kernel_x * dilation_width - input_padding_left; - const size_t index = (indirection_y) * step_height + output_x * step_width * kernel_height + kernel_x * kernel_height + kernel_y; + const size_t index = indirection_y * step_height + (output_x * step_width + kernel_x) * kernel_height + kernel_y; if (input_x < input_width) { indirection_buffer[index] = (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride); @@ -282,7 +282,7 @@ void xnn_indirection_init_dwconv2d_compressed( } else { for (size_t output_x = 0; output_x < output_width; output_x++) { for (size_t kernel_x = 0; kernel_x < 
kernel_width; kernel_x++) { - const size_t index = (indirection_y) * step_height + output_x * step_width * kernel_height + kernel_x * kernel_height + kernel_y; + const size_t index = indirection_y * step_height + (output_x * step_width + kernel_x) * kernel_height + kernel_y; indirection_buffer[index] = zero_buffer; } } @@ -291,8 +291,8 @@ void xnn_indirection_init_dwconv2d_compressed( } if (output_y_end == output_height) { - const void* last_output_pixel = indirection_buffer[(indirection_y) * step_height - 1]; - const size_t last_kernel_index = (indirection_y) * step_height - (kernel_height * kernel_width); + const void* last_output_pixel = indirection_buffer[indirection_y * step_height - 1]; + const size_t last_kernel_index = indirection_y * step_height - kernel_height * kernel_width; for (size_t tile_index = kernel_height * kernel_width; tile_index < primary_tile; tile_index++) { indirection_buffer[last_kernel_index + tile_index] = last_output_pixel; } @@ -349,27 +349,24 @@ void xnn_indirection_init_dwconv2d( } void xnn_indirection_init_maxpool2d( - xnn_operator_t op, - size_t step_height, - size_t step_width, - uint32_t log2_element_size) + const void** indirection_buffer, + const void* input, + const size_t input_pixel_stride, + const size_t input_height, + const size_t input_width, + const size_t output_height, + const size_t output_width, + const size_t kernel_height, + const size_t kernel_width, + const size_t stride_height, + const size_t stride_width, + const size_t dilation_height, + const size_t dilation_width, + const size_t input_padding_top, + const size_t input_padding_left, + const size_t step_height, + const size_t step_width) { - const void** indirection_buffer = op->indirection_buffer; - const void* input = op->input; - const size_t input_pixel_stride = op->input_pixel_stride << log2_element_size; - const size_t input_height = op->input_height; - const size_t input_width = op->input_width; - const size_t output_height = op->output_height; - const size_t output_width = op->output_width; - const size_t pooling_height = op->kernel_height; - const size_t pooling_width = op->kernel_width; - const size_t stride_height = op->stride_height; - const size_t stride_width = op->stride_width; - const size_t dilation_height = op->dilation_height; - const size_t dilation_width = op->dilation_width; - const size_t input_padding_top = op->padding_top; - const size_t input_padding_left = op->padding_left; - const bool any_dilation = (dilation_height | dilation_width) > 1; if (any_dilation) { @@ -377,7 +374,7 @@ void xnn_indirection_init_maxpool2d( const size_t adjusted_padding_top = input_padding_top % dilation_height; const size_t adjusted_padding_left = input_padding_left % dilation_width; for (size_t output_y = 0; output_y < output_height; output_y++) { - for (size_t pooling_y = 0; pooling_y < pooling_height; pooling_y++) { + for (size_t pooling_y = 0; pooling_y < kernel_height; pooling_y++) { size_t safe_input_y = output_y * stride_height; if XNN_UNPREDICTABLE(safe_input_y < adjusted_padding_top) { safe_input_y += dilation_height; @@ -390,7 +387,7 @@ void xnn_indirection_init_maxpool2d( } for (size_t output_x = 0; output_x < output_width; output_x++) { - for (size_t pooling_x = 0; pooling_x < pooling_width; pooling_x++) { + for (size_t pooling_x = 0; pooling_x < kernel_width; pooling_x++) { size_t safe_input_x = output_x * stride_width; if XNN_UNPREDICTABLE(safe_input_x < adjusted_padding_left) { safe_input_x += dilation_width; @@ -402,7 +399,7 @@ void xnn_indirection_init_maxpool2d( 
input_x = safe_input_x; } - const size_t index = output_y * step_height + output_x * step_width * pooling_height + pooling_x * pooling_height + pooling_y; + const size_t index = output_y * step_height + output_x * step_width * kernel_height + pooling_x * kernel_height + pooling_y; indirection_buffer[index] = (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride); } } @@ -412,12 +409,12 @@ void xnn_indirection_init_maxpool2d( const size_t input_x_max = input_width - 1; const size_t input_y_max = input_height - 1; for (size_t output_y = 0; output_y < output_height; output_y++) { - for (size_t pooling_y = 0; pooling_y < pooling_height; pooling_y++) { + for (size_t pooling_y = 0; pooling_y < kernel_height; pooling_y++) { const size_t input_y = min(doz(output_y * stride_height + pooling_y * dilation_height, input_padding_top), input_y_max); for (size_t output_x = 0; output_x < output_width; output_x++) { - for (size_t pooling_x = 0; pooling_x < pooling_width; pooling_x++) { + for (size_t pooling_x = 0; pooling_x < kernel_width; pooling_x++) { const size_t input_x = min(doz(output_x * stride_width + pooling_x * dilation_width, input_padding_left), input_x_max); - const size_t index = output_y * step_height + output_x * step_width * pooling_height + pooling_x * pooling_height + pooling_y; + const size_t index = output_y * step_height + output_x * step_width * kernel_height + pooling_x * kernel_height + pooling_y; indirection_buffer[index] = (const void*) ((uintptr_t) input + (input_y * input_width + input_x) * input_pixel_stride); } } diff --git a/src/microkernel-utils.c b/src/microkernel-utils.c index bab390b88058..1dbf36268640 100644 --- a/src/microkernel-utils.c +++ b/src/microkernel-utils.c @@ -14,10 +14,17 @@ size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr, size_t nr, size_t num_threads) { size_t nc = n; if (num_threads > 1) { + const size_t min_num_tiles = num_threads * XNN_GEMM_TILES_PER_THREAD; const size_t num_tile_rows = divide_round_up(m, mr) * num_groups; - nc = min(max(1, (n * num_tile_rows) / - (nr * num_threads * XNN_GEMM_TILES_PER_THREAD)) * - nr, n); + const size_t num_tile_cols = divide_round_up(min_num_tiles, num_tile_rows); + + // We are looking for an `nc` that is the smallest integer multiple of `nr` + // such that `divide_round_up(n, nc)` is `num_tile_cols`. 
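  // (Worked example: with n = 64, nr = 16 and num_tile_cols = 2, the line
  //  below yields nc = (round_up(64, 16) / (16 * 2)) * 16 = 32, i.e. two
  //  column tiles of 32 elements each; the loop that follows then declines to
  //  shrink nc to 16, since that would raise divide_round_up(n, nc) from 2 to
  //  4 tiles.)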
+ nc = max(1, round_up(n, nr) / (nr * num_tile_cols)) * nr; + while (nr < nc && divide_round_up(n, nc - nr) == divide_round_up(n, nc)) { + nc -= nr; + } + nc = min(nc, n); } return nc; diff --git a/src/microparams-init.c b/src/microparams-init.c index 9cc2ac6d27ab..f8fdab1dd624 100644 --- a/src/microparams-init.c +++ b/src/microparams-init.c @@ -1218,8 +1218,7 @@ size_t xnn_init_qu8_reduce_scalar_params( size_t xnn_update_f32_reduce_scalar_params( struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)], - float scale, - int32_t num_elements) + float scale) { params->f32.scale = scale; return sizeof(params->f32); @@ -1227,21 +1226,17 @@ size_t xnn_update_f32_reduce_scalar_params( size_t xnn_update_qs8_reduce_scalar_params( struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)], - float scale, - int32_t num_elements) + float scale) { params->qs8.scale = params->qs8.input_output_scale * scale; - params->qs8.num_elements = num_elements; return sizeof(params->qs8); } size_t xnn_update_qu8_reduce_scalar_params( struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)], - float scale, - int32_t num_elements) + float scale) { params->qu8.scale = params->qs8.input_output_scale * scale; - params->qu8.num_elements = num_elements; return sizeof(params->qu8); } @@ -1256,16 +1251,6 @@ size_t xnn_init_f32_qu8_cvt_scalar_params( return sizeof(params->f32_qu8_cvt); } -size_t xnn_init_s32_f32_cvt_scalar_params( - union xnn_unary_uparams* params, - const union xnn_unary_params* op_params, - const struct xnn_quantization_params* input_quantization, - const struct xnn_quantization_params* output_quantization) -{ - params->s32_f32_cvt.scalar.zero_point = input_quantization->zero_point; - return sizeof(params->s32_f32_cvt); -} - size_t xnn_init_qs8_cvt_scalar_params( union xnn_unary_uparams* params, const union xnn_unary_params* op_params, diff --git a/src/operator-run.c b/src/operator-run.c index 4b4f4028e148..afcf6ebbfc2c 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -15,20 +15,23 @@ #include "xnnpack.h" #include "xnnpack/common.h" #include "xnnpack/compute.h" -#include "xnnpack/config-types.h" #include "xnnpack/indirection.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microkernel-type.h" #include "xnnpack/microparams.h" -#include "xnnpack/microparams-init.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" #include "xnnpack/packq.h" #include "xnnpack/quantization.h" #include "pthreadpool.h" +#if XNN_MAX_UARCH_TYPES > 1 +#include "xnnpack/config-types.h" +#include "xnnpack/microparams-init.h" +#endif // XNN_MAX_UARCH_TYPES > 1 + void xnn_compute_transposec_2d( const struct transpose_context* context, size_t i, @@ -502,6 +505,54 @@ void xnn_compute_dqgemm( (const void*) ((uintptr_t) &context->quantization_params[mr_block_start])); } +void xnn_compute_hmp_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t group_index, size_t mr_block_start, + size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) { + const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset( + mr_block_start, context->k_scaled, context->mr, context->kr, context->sr); + const size_t cm_stride = context->cm_stride; + const size_t num_batch_dims = context->num_batch_dims; + + // Compute the group index offsets into A and B. 
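  // (Illustration of the decomposition below, assuming `batch_strides_c`
  //  holds the row-major strides of the output batch dimensions: with
  //  batch_dims_c = {2, 3}, batch_dims_a = {2, 1}, batch_dims_b = {1, 3} and
  //  batch_strides_c = {3, 1}, group_index 5 splits into per-dim indices
  //  {1, 2}; the modulo against each input's batch dims then yields
  //  group_index_a = 1 and group_index_b = 2, broadcasting A along its second
  //  batch dimension and B along its first.)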
+ const size_t group_index_c = group_index; + size_t group_index_a = 0; + size_t group_index_b = 0; + for (int k = 0; k < num_batch_dims; k++) { + // Extract the kth batch index from the group_index. + const size_t index = group_index / context->batch_strides_c[k]; + group_index %= context->batch_strides_c[k]; + + // Compute the corresponding kth group index offsets into A and B. + group_index_a = (index % context->batch_dims_a[k]) + + context->batch_dims_a[k] * group_index_a; + group_index_b = (index % context->batch_dims_b[k]) + + context->batch_dims_b[k] * group_index_b; + } + + context->qp8_ukernel.function[uarch_index]( + mr_block_size, nr_block_size, context->k_scaled, + (const void*)((uintptr_t)context->a + group_index_a * context->ga_stride + + a_offset), + (const void*)((uintptr_t)context->packed_w + + group_index_b * context->gw_stride + + nr_block_start * context->w_stride), + (void*)((uintptr_t)context->c + group_index_c * context->gc_stride + + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize)), + cm_stride, + /*dst_stride_col=*/sizeof(float), context->fused_params); +} + +void xnn_compute_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size) { + xnn_compute_hmp_grouped_qp8gemm(context, XNN_UARCH_DEFAULT, group_index, + mr_block_start, nr_block_start, mr_block_size, + nr_block_size); +} + void xnn_compute_hmp_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, @@ -1958,26 +2009,6 @@ void xnn_compute_elementwise_binary_5d( context->ukernel(context->elements, a, b, y, &context->params); } -void xnn_compute_channel_shuffle_fixed( - const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t index) -{ - const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); - void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); - - context->fixed_ukernel(context->n, x, y); -} - -void xnn_compute_channel_shuffle_variable( - const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t index) -{ - const void* x = (const void*) ((uintptr_t) context->x + index * context->x_stride); - void* y = (void*) ((uintptr_t) context->y + index * context->y_stride); - - context->variable_ukernel(context->n, context->m, x, y); -} - void xnn_compute_lut_strided( const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index) @@ -2091,29 +2122,7 @@ void xnn_compute_contiguous_reduce( void* workspace_ptr = (void*) ((uintptr_t) context->workspace + workspace_offset); output_ptr = (void*) ((uintptr_t) context->output + output_offset); - if (context->s32_f32_cvt_ukernel) { - struct xnn_s32_f32_cvt_params s32_f32_cvt_params; - s32_f32_cvt_params.scalar.zero_point = context->params.qs8.num_elements * (int32_t) context->params.qs8.input_zero_point; - context->s32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params); - struct xnn_f32_qs8_cvt_params cvt_params; - cvt_params.scalar.scale = context->params.qs8.scale; - cvt_params.scalar.output_zero_point = context->params.qs8.output_zero_point; - context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - output_ptr, (union xnn_unary_uparams*) &cvt_params); - } else if 
(context->u32_f32_cvt_ukernel) { - struct xnn_s32_f32_cvt_params s32_f32_cvt_params; - s32_f32_cvt_params.scalar.zero_point = context->params.qu8.num_elements * (int32_t) context->params.qu8.input_zero_point; - context->u32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params); - struct xnn_f32_qu8_cvt_params cvt_params; - cvt_params.scalar.scale = context->params.qu8.scale; - cvt_params.scalar.output_zero_point = context->params.qu8.output_zero_point; - context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - output_ptr, (union xnn_unary_uparams*) &cvt_params); - } else { - context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, /*params=*/NULL); - } + context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, &context->cvt_params); } } @@ -2174,29 +2183,7 @@ void xnn_compute_discontiguous_reduce( void* workspace_ptr = (void*) ((uintptr_t) context->workspace + workspace_offset); output_ptr = (void*) ((uintptr_t) context->output + output_offset); - if (context->s32_f32_cvt_ukernel) { - struct xnn_s32_f32_cvt_params s32_f32_cvt_params; - s32_f32_cvt_params.scalar.zero_point = context->params.qs8.num_elements * (int32_t) context->params.qs8.input_zero_point; - context->s32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params); - struct xnn_f32_qs8_cvt_params cvt_params; - cvt_params.scalar.scale = context->params.qs8.scale; - cvt_params.scalar.output_zero_point = context->params.qs8.output_zero_point; - context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - output_ptr, (union xnn_unary_uparams*) &cvt_params); - } else if (context->u32_f32_cvt_ukernel) { - struct xnn_s32_f32_cvt_params s32_f32_cvt_params; - s32_f32_cvt_params.scalar.zero_point = context->params.qu8.num_elements * (int32_t) context->params.qu8.input_zero_point; - context->u32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params); - struct xnn_f32_qu8_cvt_params cvt_params; - cvt_params.scalar.scale = context->params.qu8.scale; - cvt_params.scalar.output_zero_point = context->params.qu8.output_zero_point; - context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, - output_ptr, (union xnn_unary_uparams*) &cvt_params); - } else { - context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, /*params=*/NULL); - } + context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, &context->cvt_params); } } @@ -2211,8 +2198,12 @@ void xnn_compute_pad_qd8_params( } } -void xnn_compute_f16_qd8_convert( +typedef struct xnn_qd8_quantization_params(f16_quantization_params_fn)(xnn_float16 min, xnn_float16 max, xnn_float16* f32_scale); +typedef struct xnn_qd8_quantization_params(f32_quantization_params_fn)(float min, float max, float* f32_scale); + +void xnn_compute_f16_qx8_convert( const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + f16_quantization_params_fn quantization_params_function, size_t batch_index) { const size_t x_stride = context->x_stride; @@ -2224,7 +2215,7 @@ void xnn_compute_f16_qd8_convert( xnn_float16 minmax[2]; 
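  // (Per-row flow below: reduce the row to its min/max, let the supplied
  //  `quantization_params_function` turn that range into asymmetric
  //  quantization parameters -- the only point at which the qd8 and qdu8
  //  wrappers differ -- and then run the convert microkernel with the
  //  resulting per-row parameters.)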
context->rminmax_ukernel(n, input, minmax, &context->params); xnn_float16 f16_scale; - context->quantization_params[batch_index] = xnn_f16_qd8_asymmetric_quantization_params(minmax[0], minmax[1], &f16_scale); + context->quantization_params[batch_index] = quantization_params_function(minmax[0], minmax[1], &f16_scale); struct xnn_f16_qs8_cvt_params params; params.scalar.scale = f16_scale; @@ -2232,8 +2223,23 @@ void xnn_compute_f16_qd8_convert( context->convert_ukernel(n, input, output, (union xnn_unary_uparams*) ¶ms); } -void xnn_compute_f32_qd8_convert( +void xnn_compute_f16_qd8_convert( + const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index) +{ + return xnn_compute_f16_qx8_convert(context, xnn_f16_qd8_asymmetric_quantization_params, batch_index); +} + +void xnn_compute_f16_qdu8_convert( + const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index) +{ + return xnn_compute_f16_qx8_convert(context, xnn_f16_qdu8_asymmetric_quantization_params, batch_index); +} + +void xnn_compute_f32_qx8_convert( const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + f32_quantization_params_fn quantization_params_function, size_t batch_index) { const size_t x_stride = context->x_stride; @@ -2245,7 +2251,7 @@ void xnn_compute_f32_qd8_convert( float minmax[2]; context->rminmax_ukernel(n, input, minmax, &context->params); float scale; - context->quantization_params[batch_index] = xnn_f32_qd8_asymmetric_quantization_params(minmax[0], minmax[1], &scale); + context->quantization_params[batch_index] = quantization_params_function(minmax[0], minmax[1], &scale); struct xnn_f32_qs8_cvt_params params; params.scalar.scale = scale; @@ -2253,6 +2259,20 @@ void xnn_compute_f32_qd8_convert( context->convert_ukernel(n, input, output, (union xnn_unary_uparams*) ¶ms); } +void xnn_compute_f32_qd8_convert( + const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index) +{ + return xnn_compute_f32_qx8_convert(context, xnn_f32_qd8_asymmetric_quantization_params, batch_index); +} + +void xnn_compute_f32_qdu8_convert( + const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index) +{ + return xnn_compute_f32_qx8_convert(context, xnn_f32_qdu8_asymmetric_quantization_params, batch_index); +} + void xnn_compute_x32_pack_lh( const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)], size_t m_idx_start, size_t tile) { @@ -2268,15 +2288,17 @@ void xnn_compute_x32_pack_lh( void xnn_compute_f32_qp8_convert( const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t m_idx_start) { + size_t group_idx, size_t m_idx_start, size_t m_tile) { const float* lhs = (const float*)((const char*)context->lhs + - m_idx_start * context->lhs_stride); - int8_t* lhs_packed = - context->lhs_packed + - xnn_x8_packq_f32qp8_packed_offset(m_idx_start, context->k, context->mr, - context->kr, context->sr); - - context->packq_ukernel(/*m=*/1, context->k, context->mr, context->kr, + (group_idx * context->m + m_idx_start) * + context->lhs_stride); + int8_t* lhs_packed = (int8_t*)((uintptr_t)context->lhs_packed + + group_idx * context->group_stride + + xnn_x8_packq_f32qp8_packed_offset( + m_idx_start, context->k, context->mr, + context->kr, context->sr)); + + context->packq_ukernel(/*m=*/m_tile, context->k, context->mr, context->kr, context->sr, m_idx_start, lhs, context->lhs_stride, lhs_packed); } diff --git a/src/operators/argmax-pooling-nhwc.c 
b/src/operators/argmax-pooling-nhwc.c index f9a24e083201..e4ec1c55176b 100644 --- a/src/operators/argmax-pooling-nhwc.c +++ b/src/operators/argmax-pooling-nhwc.c @@ -373,7 +373,16 @@ enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32( const size_t step_width = pooling_width; const size_t step_height = pooling_size + (output_width - 1) * step_width * pooling_height; - xnn_indirection_init_maxpool2d(argmax_pooling_op, step_height, step_width, /*log2_element_size=*/XNN_LOG2_SIZEOF_FLOAT); + xnn_indirection_init_maxpool2d( + argmax_pooling_op->indirection_buffer, argmax_pooling_op->input, + argmax_pooling_op->input_pixel_stride << XNN_LOG2_SIZEOF_FLOAT, + argmax_pooling_op->input_height, argmax_pooling_op->input_width, + argmax_pooling_op->output_height, argmax_pooling_op->output_width, + argmax_pooling_op->kernel_height, argmax_pooling_op->kernel_width, + argmax_pooling_op->stride_height, argmax_pooling_op->stride_width, + argmax_pooling_op->dilation_height, argmax_pooling_op->dilation_width, + argmax_pooling_op->padding_top, argmax_pooling_op->padding_left, + step_height, step_width); argmax_pooling_op->context.argmax_pooling.indirect_input = argmax_pooling_op->indirection_buffer, diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index ccf49c4f8c5f..c19825035a69 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -26,6 +26,7 @@ #include "xnnpack/operator-utils.h" #include "xnnpack/operator.h" #include "xnnpack/pack.h" +#include "xnnpack/packq.h" #include "xnnpack/params.h" #include "pthreadpool.h" @@ -64,6 +65,7 @@ enum xnn_status create_batch_matrix_multiply_nc( batch_matrix_multiply_op->ukernel.type = xnn_microkernel_type_gemm; batch_matrix_multiply_op->ukernel.gemm = (struct xnn_ukernel_gemm) { .mr = mr, + .mr_packed = gemm_config->mr_packed, .nr = gemm_config->nr, .kr = UINT32_C(1) << gemm_config->log2_kr, .sr = UINT32_C(1) << gemm_config->log2_sr, @@ -144,26 +146,26 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights( batch_matrix_multiply_op->weights_cache, &cache_key); } + // Compute the shape and size of the packed data. + const uint32_t kr = batch_matrix_multiply_op->ukernel.gemm.kr; + const uint32_t sr = batch_matrix_multiply_op->ukernel.gemm.sr; + const size_t bias_element_size = sizeof(float); + const size_t k_stride = round_up_po2(k, kr * sr); + const size_t input_b_batch_stride = + bias_element_size + (k_stride << XNN_LOG2_SIZEOF_FLOAT); + batch_matrix_multiply_op->weights_stride = input_b_batch_stride; + // If the packed data has not been cached, pack and cache it. if (cache_offset == XNN_CACHE_NOT_FOUND) { - // Compute the shape and size of the packed data. const uint32_t nr = batch_matrix_multiply_op->ukernel.gemm.nr; - const uint32_t kr = batch_matrix_multiply_op->ukernel.gemm.kr; - const uint32_t sr = batch_matrix_multiply_op->ukernel.gemm.sr; - const size_t bias_element_size = sizeof(float); const size_t n_stride = round_up(n, nr); - const size_t k_stride = round_up_po2(k, kr * sr); - const size_t input_b_batch_stride = - (n_stride * bias_element_size + - ((n_stride * k_stride) << XNN_LOG2_SIZEOF_FLOAT)); - const size_t packed_size = batch_size_b * input_b_batch_stride; + const size_t packed_size = batch_size_b * n_stride * input_b_batch_stride; const size_t aligned_size = round_up_po2(packed_size, XNN_ALLOCATION_ALIGNMENT); // Allocate the packed weights. 
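  // (Sizing note: `input_b_batch_stride` above is the stride of one packed
  //  output channel -- a float of bias plus k rounded up to a multiple of
  //  kr * sr floats -- so a single batch of B takes n_stride times that, and
  //  the allocation below requests batch_size_b * n_stride *
  //  input_b_batch_stride bytes, rounded up to XNN_ALLOCATION_ALIGNMENT.)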
void* packed_data = xnn_get_pointer_to_write_weights( batch_matrix_multiply_op, aligned_size, /*padding_byte=*/0); - batch_matrix_multiply_op->weights_stride = input_b_batch_stride / n_stride; if (packed_data == NULL) { xnn_log_error( "failed to allocate %zu bytes for %s operator packed weights", @@ -240,7 +242,7 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f16( union xnn_f16_minmax_params params; if XNN_LIKELY(gemm_config->init.f16 != NULL) { - gemm_config->init.f16(¶ms, xnn_float16_from_float(-INFINITY), + gemm_config->init.f16(¶ms, xnn_float16_from_float(-INFINITY), xnn_float16_from_float(INFINITY)); } @@ -252,17 +254,16 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_f16( batch_matrix_multiply_op_out); } -enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( +enum xnn_status create_batch_matrix_multiply_nc_qx8_f32_qc8w( size_t batch_size_b, size_t k, size_t n, const int8_t* data_b, const float* scale_b, uint32_t flags, + const struct xnn_gemm_config *gemm_config, enum xnn_operator_type expected_operator_type, xnn_operator_t* batch_matrix_multiply_op_out) { - const struct xnn_gemm_config* gemm_config = - xnn_init_qd8_f32_qc8w_gemm_config(); if (gemm_config == NULL) { xnn_log_error( "failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string( - xnn_operator_type_batch_matrix_multiply_nc_qd8_f32_qc8w)); + expected_operator_type)); return xnn_status_unsupported_hardware; } @@ -279,7 +280,7 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( enum xnn_status status = create_batch_matrix_multiply_nc( flags, ¶ms, sizeof(params), gemm_config, gemm_ukernels, - xnn_operator_type_batch_matrix_multiply_nc_qd8_f32_qc8w, + expected_operator_type, batch_matrix_multiply_op_out); if (status != xnn_status_success) { return status; @@ -303,22 +304,22 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( batch_matrix_multiply_op->weights_cache, &cache_key); } + const uint32_t kr = batch_matrix_multiply_op->ukernel.gemm.kr; + const uint32_t sr = batch_matrix_multiply_op->ukernel.gemm.sr; + const size_t extra_bytes = 2 * sizeof(float); + const size_t k_stride = round_up_po2(k, kr * sr); + const size_t weights_stride = + gemm_config->packed_stride_weights_and_biases + ? gemm_config->packed_stride_weights_and_biases( + gemm_config, k, k_stride, extra_bytes) + : (k_stride << XNN_LOG2_SIZEOF_INT8_T) + extra_bytes + + sizeof(int32_t); + batch_matrix_multiply_op->weights_stride = weights_stride; + // If the packed data has not been cached, pack and cache it. if (cache_offset == XNN_CACHE_NOT_FOUND) { const uint32_t nr = batch_matrix_multiply_op->ukernel.gemm.nr; - const uint32_t kr = batch_matrix_multiply_op->ukernel.gemm.kr; - const uint32_t sr = batch_matrix_multiply_op->ukernel.gemm.sr; - const size_t extra_bytes = 2 * sizeof(float); - const size_t k_stride = round_up_po2(k, kr * sr); const size_t n_stride = round_up(n, nr); - const size_t weights_stride = - gemm_config->packed_stride_weights_and_biases - ? 
gemm_config->packed_stride_weights_and_biases( - gemm_config, k, k_stride, extra_bytes) - : (k_stride << XNN_LOG2_SIZEOF_INT8_T) + extra_bytes + - sizeof(int32_t); - assert(weights_stride == (k_stride << XNN_LOG2_SIZEOF_INT8_T) + - extra_bytes + sizeof(int32_t)); const size_t packed_size = batch_size_b * n_stride * weights_stride; const size_t aligned_size = round_up_po2(packed_size, XNN_ALLOCATION_ALIGNMENT); @@ -397,6 +398,42 @@ enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( return xnn_status_success; } +enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( + size_t batch_size_b, size_t k, size_t n, const int8_t* data_b, + const float* scale_b, uint32_t flags, + xnn_operator_t* batch_matrix_multiply_op_out) { + const struct xnn_gemm_config* gemm_config = + xnn_init_qd8_f32_qc8w_gemm_config(); + return create_batch_matrix_multiply_nc_qx8_f32_qc8w( + batch_size_b, k, n, data_b, scale_b, flags, gemm_config, + xnn_operator_type_batch_matrix_multiply_nc_qd8_f32_qc8w, + batch_matrix_multiply_op_out); +} + +enum xnn_status xnn_create_batch_matrix_multiply_nc_qp8_f32_qc8w( + size_t batch_size_b, size_t k, size_t n, const int8_t* data_b, + const float* scale_b, uint32_t flags, + xnn_operator_t* batch_matrix_multiply_op_out) { + const struct xnn_gemm_config* gemm_config = + xnn_init_qp8_f32_qc8w_gemm_config(); + return create_batch_matrix_multiply_nc_qx8_f32_qc8w( + batch_size_b, k, n, data_b, scale_b, flags, gemm_config, + xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w, + batch_matrix_multiply_op_out); +} + +enum xnn_status xnn_create_batch_matrix_multiply_nc_qdu8_f32_qc8w( + size_t batch_size_b, size_t k, size_t n, const int8_t* data_b, + const float* scale_b, uint32_t flags, + xnn_operator_t* batch_matrix_multiply_op_out) { + const struct xnn_gemm_config* gemm_config = + xnn_init_qdu8_f32_qc8w_gemm_config(); + return create_batch_matrix_multiply_nc_qx8_f32_qc8w( + batch_size_b, k, n, data_b, scale_b, flags, gemm_config, + xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w, + batch_matrix_multiply_op_out); +} + static enum xnn_status reshape_batch_matrix_multiply_nc( xnn_operator_t batch_matrix_multiply_op, enum xnn_operator_type expected_operator_type, size_t num_batch_dims, @@ -495,6 +532,9 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( mr = 1; } + const uint32_t mr_packed = + m > 1 ? batch_matrix_multiply_op->ukernel.gemm.mr_packed : 1; + assert(mr != 0 && mr <= XNN_MAX_MR); struct xnn_hmp_gemm_ukernel gemm_ukernel = gemm_cases[mr-1]; @@ -503,6 +543,8 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( switch (batch_matrix_multiply_op->type) { case xnn_operator_type_batch_matrix_multiply_nc_qd8_f32_qc8w: + case xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w: + case xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w: // Nothing to do here, the `B` matrix has already been packed. break; @@ -596,14 +638,28 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( XNN_UNREACHABLE; } - const size_t w_stride = - (round_up_po2(k, kr * sr) << log2_input_a_element_size) + - bias_element_size + w_stride_extra_bytes; + const bool is_qp8_ukernel = + (batch_matrix_multiply_op->type == + xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w); + const size_t k_scaled = k << log2_input_a_element_size; + const size_t a_stride = + is_qp8_ukernel ? 
xnn_x8_packq_f32qp8_packed_offset( + mr, k, mr, batch_matrix_multiply_op->ukernel.gemm.kr, + batch_matrix_multiply_op->ukernel.gemm.sr) + : k_scaled; + const size_t ga_stride = + is_qp8_ukernel ? xnn_x8_packq_f32qp8_packed_size(m, k, mr_packed, kr, sr) + : m * k_scaled; + const size_t w_stride = + is_qp8_ukernel ? batch_matrix_multiply_op->weights_stride + : (round_up_po2(k, kr * sr) << log2_input_a_element_size) + + bias_element_size + w_stride_extra_bytes; + batch_matrix_multiply_op->context.gemm.gemm.gemm = (struct gemm_context){ .k_scaled = k_scaled, - .a_stride = k_scaled, - .ga_stride = m * k_scaled, + .a_stride = a_stride, + .ga_stride = ga_stride, .w_stride = w_stride, .gw_stride = w_stride * round_up(n, nr), .cm_stride = n << log2_output_element_size, @@ -631,19 +687,35 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( size_t nc = xnn_gemm_best_nc(batch_size_c, m, n, mr, nr, num_threads); #if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { - gemm_compute->type = xnn_parallelization_type_3d_tile_2d_with_uarch; + if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { + gemm_compute->type = xnn_parallelization_type_3d_tile_2d_with_uarch; + if (is_qp8_ukernel) { + gemm_compute->task_3d_tile_2d_with_id = + (pthreadpool_task_3d_tile_2d_with_id_t) + xnn_compute_hmp_grouped_qp8gemm; + } else { gemm_compute->task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t)xnn_compute_hmp_grouped_gemm; + } + } else { + gemm_compute->type = xnn_parallelization_type_3d_tile_2d; + if (is_qp8_ukernel) { + gemm_compute->task_3d_tile_2d = + (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_qp8gemm; } else { - gemm_compute->type = xnn_parallelization_type_3d_tile_2d; gemm_compute->task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm; } - #else + } +#else gemm_compute->type = xnn_parallelization_type_3d_tile_2d; - gemm_compute->task_3d_tile_2d = - (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm; + if (is_qp8_ukernel) { + gemm_compute->task_3d_tile_2d = + (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_qp8gemm; + } else { + gemm_compute->task_3d_tile_2d = + (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm; + } #endif gemm_compute->range[0] = batch_size_c; gemm_compute->range[1] = m; @@ -712,6 +784,44 @@ enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qd8_f32_qc8w( pthreadpool_get_threads_count(threadpool)); } +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qp8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims, + const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k, + size_t n, pthreadpool_t threadpool) { + return reshape_batch_matrix_multiply_nc( + batch_matrix_multiply_op, + xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w, num_batch_dims, + batch_dims_a, batch_dims_b, m, k, n, /*workspace_size=*/NULL, + /*workspace_alignment=*/NULL, + /*log2_input_a_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*log2_input_b_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*bias_element_size=*/sizeof(int32_t), + /*w_stride_extra_bytes=*/2 * sizeof(float), + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &batch_matrix_multiply_op->params.f32_minmax, + sizeof(batch_matrix_multiply_op->params.f32_minmax), + pthreadpool_get_threads_count(threadpool)); +} + +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qdu8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims, + const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k, + size_t n, pthreadpool_t 
threadpool) { + return reshape_batch_matrix_multiply_nc( + batch_matrix_multiply_op, + xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w, num_batch_dims, + batch_dims_a, batch_dims_b, m, k, n, /*workspace_size=*/NULL, + /*workspace_alignment=*/NULL, + /*log2_input_a_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*log2_input_b_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*bias_element_size=*/sizeof(int32_t), + /*w_stride_extra_bytes=*/2 * sizeof(float), + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &batch_matrix_multiply_op->params.f32_minmax, + sizeof(batch_matrix_multiply_op->params.f32_minmax), + pthreadpool_get_threads_count(threadpool)); +} + static enum xnn_status setup_batch_matrix_multiply_nc( xnn_operator_t batch_matrix_multiply_op, enum xnn_operator_type expected_operator_type, const void* input_a, @@ -794,3 +904,24 @@ enum xnn_status xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w( quantization_params, /*input_b=*/NULL, packed_weights(batch_matrix_multiply_op), output); } + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_qp8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, const int8_t* input_a, + float* output) { + return setup_batch_matrix_multiply_nc( + batch_matrix_multiply_op, + xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w, input_a, + /*quantization_params=*/NULL, /*input_b=*/NULL, + packed_weights(batch_matrix_multiply_op), output); +} + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_qdu8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, const int8_t* input_a, + const struct xnn_quantization_params* quantization_params, + float* output) { + return setup_batch_matrix_multiply_nc( + batch_matrix_multiply_op, + xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w, input_a, + quantization_params, /*input_b=*/NULL, + packed_weights(batch_matrix_multiply_op), output); +} diff --git a/src/operators/channel-shuffle-nc.c b/src/operators/channel-shuffle-nc.c deleted file mode 100644 index 78ab65d8d707..000000000000 --- a/src/operators/channel-shuffle-nc.c +++ /dev/null @@ -1,311 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include -#include -#include -#include - -#include "xnnpack.h" -#include "xnnpack/allocator.h" -#include "xnnpack/common.h" -#include "xnnpack/compute.h" -#include "xnnpack/config-types.h" -#include "xnnpack/config.h" -#include "xnnpack/log.h" -#include "xnnpack/operator-type.h" -#include "xnnpack/operator.h" -#include "xnnpack/params.h" -#include "pthreadpool.h" - -static enum xnn_status create_channel_shuffle_nc( - size_t groups, - size_t group_channels, - size_t input_stride, - size_t output_stride, - uint32_t flags, - const struct xnn_zip_config* zip_config, - enum xnn_operator_type operator_type, - xnn_operator_t* channel_shuffle_op_out) -{ - xnn_operator_t channel_shuffle_op = NULL; - enum xnn_status status = xnn_status_uninitialized; - - if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { - xnn_log_error("failed to create %s operator: XNNPACK is not initialized", - xnn_operator_type_to_string(operator_type)); - goto error; - } - - status = xnn_status_invalid_parameter; - - if (groups <= 1) { - xnn_log_error( - "failed to create %s operator with %zu groups: at least two groups required", - xnn_operator_type_to_string(operator_type), groups); - goto error; - } - - if (group_channels == 0) { - xnn_log_error( - "failed to create %s operator with %zu group channels: number of group channels must be non-zero", - xnn_operator_type_to_string(operator_type), group_channels); - goto error; - } - - const size_t channels = groups * group_channels; - if (input_stride < channels) { - xnn_log_error( - "failed to create %s operator with input element stride of %zu: " - "stride must be at least as large as the number of channels (%zux%zu)", - xnn_operator_type_to_string(operator_type), input_stride, groups, group_channels); - goto error; - } - - if (output_stride < channels) { - xnn_log_error( - "failed to create %s operator with output element stride of %zu: " - "stride must be at least as large as the number of channels (%zux%zu)", - xnn_operator_type_to_string(operator_type), output_stride, groups, group_channels); - goto error; - } - - status = xnn_status_out_of_memory; - - channel_shuffle_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator)); - if (channel_shuffle_op == NULL) { - xnn_log_error( - "failed to allocate %zu bytes for %s operator descriptor", - sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type)); - goto error; - } - - channel_shuffle_op->groups = groups; - channel_shuffle_op->group_channels = group_channels; - channel_shuffle_op->input_pixel_stride = input_stride; - channel_shuffle_op->output_pixel_stride = output_stride; - - channel_shuffle_op->type = operator_type; - channel_shuffle_op->flags = flags; - channel_shuffle_op->zip_config = zip_config; - - channel_shuffle_op->state = xnn_run_state_invalid; - - *channel_shuffle_op_out = channel_shuffle_op; - return xnn_status_success; - -error: - xnn_delete_operator(channel_shuffle_op); - return status; -} - - -enum xnn_status xnn_create_channel_shuffle_nc_x8( - size_t groups, - size_t group_channels, - size_t input_stride, - size_t output_stride, - uint32_t flags, - xnn_operator_t* channel_shuffle_op_out) -{ - const struct xnn_zip_config* zip_config = xnn_init_x8_zip_config(); - assert(zip_config != NULL); - return create_channel_shuffle_nc( - groups, - group_channels, - input_stride, - output_stride, - flags, - zip_config, - xnn_operator_type_channel_shuffle_nc_x8, - channel_shuffle_op_out); -} - -enum xnn_status xnn_create_channel_shuffle_nc_x32( - size_t groups, - size_t group_channels, - 
size_t input_stride, - size_t output_stride, - uint32_t flags, - xnn_operator_t* channel_shuffle_op_out) -{ - const struct xnn_zip_config* zip_config = xnn_init_x32_zip_config(); - if (zip_config == NULL) { - xnn_log_error( - "failed to create %s operator: unsupported hardware configuration", - xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x32)); - return xnn_status_unsupported_hardware; - } - return create_channel_shuffle_nc( - groups, - group_channels, - input_stride, - output_stride, - flags, - zip_config, - xnn_operator_type_channel_shuffle_nc_x32, - channel_shuffle_op_out); -} - -static enum xnn_status reshape_channel_shuffle_nc( - xnn_operator_t channel_shuffle_op, - size_t batch_size, - uint32_t log2_element_size, - const struct xnn_zip_config zip[restrict XNN_MIN_ELEMENTS(1)]) -{ - channel_shuffle_op->state = xnn_run_state_invalid; - - if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) { - xnn_log_error("failed to reshape %s operator: XNNPACK is not initialized", - xnn_operator_type_to_string(channel_shuffle_op->type)); - return xnn_status_uninitialized; - } - - if (batch_size == 0) { - channel_shuffle_op->state = xnn_run_state_skip; - return xnn_status_success; - } - - channel_shuffle_op->batch_size = batch_size; - - const size_t groups = channel_shuffle_op->groups; - channel_shuffle_op->context.channel_shuffle = (struct channel_shuffle_context) { - .x_stride = channel_shuffle_op->input_pixel_stride << log2_element_size, - .y_stride = channel_shuffle_op->output_pixel_stride << log2_element_size, - .n = channel_shuffle_op->group_channels << log2_element_size, - .m = groups, - }; - channel_shuffle_op->compute[0].type = xnn_parallelization_type_1d; - channel_shuffle_op->compute[0].range[0] = batch_size; - switch (groups) { - case 2: - channel_shuffle_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed; - channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x2; - break; - case 3: - channel_shuffle_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed; - channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x3; - break; - case 4: - channel_shuffle_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_fixed; - channel_shuffle_op->context.channel_shuffle.fixed_ukernel = zip->x4; - break; - default: - channel_shuffle_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_channel_shuffle_variable; - channel_shuffle_op->context.channel_shuffle.variable_ukernel = zip->xm; - break; - case 0: - case 1: - XNN_UNREACHABLE; - } - channel_shuffle_op->state = xnn_run_state_needs_setup; - - return xnn_status_success; -} - -enum xnn_status xnn_reshape_channel_shuffle_nc_x8( - xnn_operator_t channel_shuffle_op, - size_t batch_size, - pthreadpool_t threadpool) -{ - if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_nc_x8) { - xnn_log_error("failed to reshape operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x8), - xnn_operator_type_to_string(channel_shuffle_op->type)); - return xnn_status_invalid_parameter; - } - - return reshape_channel_shuffle_nc( - channel_shuffle_op, - batch_size, - /*log2_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, - channel_shuffle_op->zip_config); -} - -enum xnn_status xnn_reshape_channel_shuffle_nc_x32( - xnn_operator_t channel_shuffle_op, - size_t batch_size, - pthreadpool_t threadpool) -{ - if (channel_shuffle_op->type != 
xnn_operator_type_channel_shuffle_nc_x32) { - xnn_log_error("failed to reshape operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x32), - xnn_operator_type_to_string(channel_shuffle_op->type)); - return xnn_status_invalid_parameter; - } - - return reshape_channel_shuffle_nc( - channel_shuffle_op, - batch_size, - /*log2_element_size=*/XNN_LOG2_SIZEOF_UINT32_T, - channel_shuffle_op->zip_config); -} - -static enum xnn_status setup_channel_shuffle_nc( - xnn_operator_t channel_shuffle_op, - const void* input, - void* output) -{ - switch (channel_shuffle_op->state) { - case xnn_run_state_skip: - return xnn_status_success; - case xnn_run_state_invalid: - xnn_log_error( - "failed to setup %s operator: operator has not been reshaped yet", - xnn_operator_type_to_string(channel_shuffle_op->type)); - return xnn_status_invalid_state; - case xnn_run_state_needs_setup: - // Operator has been reshaped, but not setup, continue with setup. - case xnn_run_state_ready: - // Operator has been reshaped, and we are setting up with different pointers. - break; - } - - channel_shuffle_op->context.channel_shuffle.x = input; - channel_shuffle_op->context.channel_shuffle.y = output; - - channel_shuffle_op->state = xnn_run_state_ready; - - return xnn_status_success; -} - -enum xnn_status xnn_setup_channel_shuffle_nc_x8( - xnn_operator_t channel_shuffle_op, - const void* input, - void* output) -{ - if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_nc_x8) { - xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x8), - xnn_operator_type_to_string(channel_shuffle_op->type)); - return xnn_status_invalid_parameter; - } - - return setup_channel_shuffle_nc( - channel_shuffle_op, - input, - output); -} - -enum xnn_status xnn_setup_channel_shuffle_nc_x32( - xnn_operator_t channel_shuffle_op, - const void* input, - void* output) -{ - if (channel_shuffle_op->type != xnn_operator_type_channel_shuffle_nc_x32) { - xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_channel_shuffle_nc_x32), - xnn_operator_type_to_string(channel_shuffle_op->type)); - return xnn_status_invalid_parameter; - } - - return setup_channel_shuffle_nc( - channel_shuffle_op, - input, - output); -} diff --git a/src/operators/convolution-nchw.c b/src/operators/convolution-nchw.c index 6be0bf5a2b1e..b3880291ef9f 100644 --- a/src/operators/convolution-nchw.c +++ b/src/operators/convolution-nchw.c @@ -940,6 +940,54 @@ enum xnn_status xnn_create_convolution2d_nchw_f32( return status; } +enum xnn_status xnn_create_convolution2d_nchw_f32_f16( + uint32_t input_padding_top, uint32_t input_padding_right, + uint32_t input_padding_bottom, uint32_t input_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height, + uint32_t subsampling_width, uint32_t dilation_height, + uint32_t dilation_width, uint32_t groups, size_t group_input_channels, + size_t group_output_channels, size_t input_channel_stride, + size_t output_channel_stride, const void* kernel, const void* bias, + float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out) { + // Convert the `f16` kernel and bias to `f32` in temporary buffers. 
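  // (Sketch of the approach: widen the f16 kernel -- and, unless
  //  XNN_FLAG_FP32_STATIC_BIASES is set, the f16 bias -- into temporary f32
  //  buffers, delegate to xnn_create_convolution2d_nchw_f32, which packs its
  //  own copy of the weights, and release the temporaries right afterwards.)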
+ const size_t num_kernel_entries = groups * group_input_channels * + group_output_channels * kernel_width * + kernel_height; + float* fp32_kernel_buffer = + (float*)xnn_allocate_memory(num_kernel_entries * sizeof(float)); + float* fp32_bias_buffer = NULL; + const xnn_float16* f16_kernel = (const xnn_float16*)kernel; + const xnn_float16* f16_bias = (const xnn_float16*)bias; + for (size_t i = 0; i < num_kernel_entries; ++i) { + fp32_kernel_buffer[i] = xnn_float16_to_float(f16_kernel[i]); + } + if (bias && !(flags & XNN_FLAG_FP32_STATIC_BIASES)) { + fp32_bias_buffer = (float*)xnn_allocate_memory( + groups * group_output_channels * sizeof(float)); + for (size_t i = 0; i < groups * group_output_channels; ++i) { + fp32_bias_buffer[i] = xnn_float16_to_float(f16_bias[i]); + } + bias = fp32_bias_buffer; + } + + // Delegate creation to the `f32` operator. + enum xnn_status status = xnn_create_convolution2d_nchw_f32( + input_padding_top, input_padding_right, input_padding_bottom, + input_padding_left, kernel_height, kernel_width, subsampling_height, + subsampling_width, dilation_height, dilation_width, groups, + group_input_channels, group_output_channels, input_channel_stride, + output_channel_stride, fp32_kernel_buffer, bias, output_min, output_max, + flags, code_cache, weights_cache, convolution_op_out); + + // Release temporary `f32` buffers. + xnn_release_memory(fp32_kernel_buffer); + xnn_release_memory(fp32_bias_buffer); + + return status; +} + static enum xnn_status reshape_convolution2d_nchw( xnn_operator_t convolution_op, enum xnn_operator_type expected_operator_type, diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index 5269796d8965..bfc0be24eb70 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -773,7 +773,7 @@ static enum xnn_status create_convolution2d_nhwc( return status; } -enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w( +enum xnn_status create_convolution2d_nhwc_qx8_f16_qc8w( uint32_t input_padding_top, uint32_t input_padding_right, uint32_t input_padding_bottom, @@ -797,19 +797,21 @@ enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config *gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* convolution_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qd8_f16_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qd8_f16_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } const xnn_float16 fp16_output_min = xnn_float16_from_float(output_min); @@ -825,7 +827,6 @@ enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w( const struct xnn_qs8_packing_params packing_params = { .input_zero_point = 1, }; - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f16_qc8w_gemm_config(); if (gemm_config == NULL) { return xnn_status_unsupported_hardware; } @@ -869,13 +870,81 @@ enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w( /*vmulcaddc_config=*/NULL, /*linear_activation=*/false, /*relu_activation=*/false, - 
/*operator_type=*/xnn_operator_type_convolution_nhwc_qd8_f16_qc8w, + /*operator_type=*/expected_operator_type, /*dynamic_quantization=*/true, /*weights_cache=*/weights_cache, convolution_op_out); } -enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w( +enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f16_qc8w_gemm_config(); + return create_convolution2d_nhwc_qx8_f16_qc8w(input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, + kernel_height, kernel_width, subsampling_height, subsampling_width, dilation_height, dilation_width, + groups, group_input_channels, group_output_channels, input_channel_stride, output_channel_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, weights_cache, gemm_config, + xnn_operator_type_convolution_nhwc_qd8_f16_qc8w, convolution_op_out); +} + +enum xnn_status xnn_create_convolution2d_nhwc_qdu8_f16_qc8w( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f16_qc8w_gemm_config(); + return create_convolution2d_nhwc_qx8_f16_qc8w(input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, + kernel_height, kernel_width, subsampling_height, subsampling_width, dilation_height, dilation_width, + groups, group_input_channels, group_output_channels, input_channel_stride, output_channel_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, weights_cache, gemm_config, + xnn_operator_type_convolution_nhwc_qdu8_f16_qc8w, convolution_op_out); +} + +enum xnn_status create_convolution2d_nhwc_qx8_f32_qc8w( uint32_t input_padding_top, uint32_t input_padding_right, uint32_t input_padding_bottom, @@ -899,31 +968,32 @@ enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config *gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* convolution_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - 
xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (output_min > output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound", - xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qd8_f32_qc8w), output_min, output_max); + xnn_operator_type_to_string(expected_operator_type), output_min, output_max); return xnn_status_invalid_parameter; } const struct xnn_qs8_packing_params packing_params = { .input_zero_point = 1, }; - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc8w_gemm_config(); assert(gemm_config != NULL); union xnn_f32_minmax_params gemm_params; @@ -965,12 +1035,86 @@ enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w( /*vmulcaddc_config=*/NULL, /*linear_activation=*/false, /*relu_activation=*/false, - /*operator_type=*/xnn_operator_type_convolution_nhwc_qd8_f32_qc8w, + /*operator_type=*/expected_operator_type, /*dynamic_quantization=*/true, /*weights_cache=*/weights_cache, convolution_op_out); } +enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc8w_gemm_config(); + return create_convolution2d_nhwc_qx8_f32_qc8w(input_padding_top, + input_padding_right, + input_padding_bottom, + input_padding_left, + kernel_height, kernel_width, subsampling_height, subsampling_width, dilation_height, + dilation_width, groups, group_input_channels, group_output_channels, input_channel_stride, output_channel_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, + weights_cache, gemm_config, xnn_operator_type_convolution_nhwc_qd8_f32_qc8w, convolution_op_out); +} + +enum xnn_status xnn_create_convolution2d_nhwc_qdu8_f32_qc8w( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out) +{ + 
const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_gemm_config(); + return create_convolution2d_nhwc_qx8_f32_qc8w(input_padding_top, + input_padding_right, + input_padding_bottom, + input_padding_left, + kernel_height, kernel_width, subsampling_height, subsampling_width, dilation_height, + dilation_width, groups, group_input_channels, group_output_channels, input_channel_stride, output_channel_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, + weights_cache, gemm_config, xnn_operator_type_convolution_nhwc_qdu8_f32_qc8w, convolution_op_out); +} + enum xnn_status xnn_create_convolution2d_nhwc_qu8( uint32_t input_padding_top, uint32_t input_padding_right, @@ -2589,7 +2733,7 @@ static enum xnn_status reshape_convolution2d_nhwc( } } -enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w( +enum xnn_status reshape_convolution2d_nhwc_qx8_f16_qc8w( xnn_operator_t convolution_op, size_t batch_size, size_t input_height, @@ -2598,6 +2742,7 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w( size_t* workspace_alignment, size_t* output_height_out, size_t* output_width_out, + enum xnn_operator_type expected_operator_type, pthreadpool_t threadpool) { convolution_op->last_input_height = convolution_op->input_height; @@ -2617,7 +2762,7 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w( } } return reshape_convolution2d_nhwc( - convolution_op, xnn_operator_type_convolution_nhwc_qd8_f16_qc8w, + convolution_op, expected_operator_type, batch_size, input_height, input_width, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -2630,7 +2775,39 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w( threadpool); } -enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( +enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool) +{ + return reshape_convolution2d_nhwc_qx8_f16_qc8w(convolution_op, batch_size, input_height, input_width, workspace_size, + workspace_alignment, output_height_out, output_width_out, + xnn_operator_type_convolution_nhwc_qd8_f16_qc8w, threadpool); +} + +enum xnn_status xnn_reshape_convolution2d_nhwc_qdu8_f16_qc8w( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool) +{ + return reshape_convolution2d_nhwc_qx8_f16_qc8w(convolution_op, batch_size, input_height, input_width, workspace_size, + workspace_alignment, output_height_out, output_width_out, + xnn_operator_type_convolution_nhwc_qdu8_f16_qc8w, threadpool); +} + +enum xnn_status reshape_convolution2d_nhwc_qx8_f32_qc8w( xnn_operator_t convolution_op, size_t batch_size, size_t input_height, @@ -2639,6 +2816,7 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( size_t* workspace_alignment, size_t* output_height_out, size_t* output_width_out, + enum xnn_operator_type expected_operator_type, pthreadpool_t threadpool) { convolution_op->last_input_height = convolution_op->input_height; @@ -2658,7 +2836,7 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( } } return reshape_convolution2d_nhwc( - convolution_op, 
xnn_operator_type_convolution_nhwc_qd8_f32_qc8w, + convolution_op, expected_operator_type, batch_size, input_height, input_width, /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, @@ -2671,6 +2849,36 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( threadpool); } +enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool) +{ + return reshape_convolution2d_nhwc_qx8_f32_qc8w(convolution_op, batch_size, input_height, input_width, workspace_size, workspace_alignment, + output_height_out, output_width_out, xnn_operator_type_convolution_nhwc_qd8_f32_qc8w, threadpool); +} + +enum xnn_status xnn_reshape_convolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool) +{ + return reshape_convolution2d_nhwc_qx8_f32_qc8w(convolution_op, batch_size, input_height, input_width, workspace_size, workspace_alignment, + output_height_out, output_width_out, xnn_operator_type_convolution_nhwc_qdu8_f32_qc8w, threadpool); +} + enum xnn_status xnn_reshape_convolution2d_nhwc_qu8( xnn_operator_t convolution_op, size_t batch_size, @@ -2931,6 +3139,19 @@ enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f16_qc8w( /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T); } +enum xnn_status xnn_setup_convolution2d_nhwc_qdu8_f16_qc8w( + xnn_operator_t convolution_op, + void* workspace, + const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_convolution2d_nhwc( + convolution_op, xnn_operator_type_convolution_nhwc_qdu8_f16_qc8w, + workspace, input, output, quantization_params, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T); +} + enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w( xnn_operator_t convolution_op, void* workspace, @@ -2944,6 +3165,19 @@ enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w( /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T); } +enum xnn_status xnn_setup_convolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t convolution_op, + void* workspace, + const uint8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_convolution2d_nhwc( + convolution_op, xnn_operator_type_convolution_nhwc_qdu8_f32_qc8w, + workspace, input, output, quantization_params, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T); +} + enum xnn_status xnn_setup_convolution2d_nhwc_qu8( xnn_operator_t convolution_op, void* workspace, diff --git a/src/operators/deconvolution-nhwc.c b/src/operators/deconvolution-nhwc.c index cfa1f97e7531..21c3066ea217 100644 --- a/src/operators/deconvolution-nhwc.c +++ b/src/operators/deconvolution-nhwc.c @@ -896,7 +896,7 @@ enum xnn_status xnn_create_deconvolution2d_nhwc_f16( deconvolution_op_out); } -enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( +enum xnn_status create_deconvolution2d_nhwc_qx8_f32_qc8w( uint32_t output_padding_top, uint32_t output_padding_right, uint32_t output_padding_bottom, @@ -920,30 +920,31 @@ enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t 
weights_cache, + const struct xnn_gemm_config * gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* deconvolution_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (output_min > output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound", - xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w), output_min, output_max); + xnn_operator_type_to_string(expected_operator_type), output_min, output_max); return xnn_status_invalid_parameter; } const struct xnn_qs8_packing_params packing_params = { .input_zero_point = 1, }; - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc8w_gemm_config(); assert(gemm_config != NULL); union xnn_f32_minmax_params params; @@ -972,7 +973,7 @@ enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( xnn_init_qs8_qc8w_scale_fp32_params, kernel_scale, ¶ms, sizeof(params), gemm_config, &gemm_config->minmax, - xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w, + expected_operator_type, /*dynamic_quantization=*/true, /*weights_cache=*/weights_cache, deconvolution_op_out); @@ -980,6 +981,86 @@ enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( return status; } +enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc8w_gemm_config(); + return create_deconvolution2d_nhwc_qx8_f32_qc8w(output_padding_top, output_padding_right, + output_padding_bottom, output_padding_left, + kernel_height, kernel_width, + stride_height, stride_width, + dilation_height, dilation_width, + groups, group_input_channels, group_output_channels, + input_pixel_stride, output_pixel_stride, + kernel_scale, kernel, bias, output_min, output_max, + flags, code_cache, weights_cache, + gemm_config, xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w, + deconvolution_op_out); +} + +enum xnn_status xnn_create_deconvolution2d_nhwc_qdu8_f32_qc8w( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t 
group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_gemm_config(); + return create_deconvolution2d_nhwc_qx8_f32_qc8w(output_padding_top, output_padding_right, + output_padding_bottom, output_padding_left, + kernel_height, kernel_width, + stride_height, stride_width, + dilation_height, dilation_width, + groups, group_input_channels, group_output_channels, + input_pixel_stride, output_pixel_stride, + kernel_scale, kernel, bias, output_min, output_max, + flags, code_cache, weights_cache, + gemm_config, xnn_operator_type_deconvolution_nhwc_qdu8_f32_qc8w, + deconvolution_op_out); +} + enum xnn_status xnn_create_deconvolution2d_nhwc_f32( uint32_t output_padding_top, uint32_t output_padding_right, @@ -1831,7 +1912,7 @@ enum xnn_status xnn_reshape_deconvolution2d_nhwc_f16( threadpool); } -enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w( +enum xnn_status reshape_deconvolution2d_nhwc_qx8_f32_qc8w( xnn_operator_t deconvolution_op, size_t batch_size, size_t input_height, @@ -1840,11 +1921,12 @@ enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w( uint32_t adjustment_width, size_t* output_height_out, size_t* output_width_out, + enum xnn_operator_type expected_operator_type, pthreadpool_t threadpool) { - if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w) { + if (deconvolution_op->type != expected_operator_type) { xnn_log_error("failed to reshape operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w), + xnn_operator_type_to_string(expected_operator_type), xnn_operator_type_to_string(deconvolution_op->type)); return xnn_status_invalid_parameter; } @@ -1875,6 +1957,40 @@ enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w( threadpool); } +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool) +{ + return reshape_deconvolution2d_nhwc_qx8_f32_qc8w(deconvolution_op, batch_size, input_height, + input_width, adjustment_height, adjustment_width, output_height_out, + output_width_out, xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w, + threadpool); +} + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool) +{ + return reshape_deconvolution2d_nhwc_qx8_f32_qc8w(deconvolution_op, batch_size, input_height, + input_width, adjustment_height, adjustment_width, output_height_out, + output_width_out, xnn_operator_type_deconvolution_nhwc_qdu8_f32_qc8w, + threadpool); +} + enum xnn_status xnn_reshape_deconvolution2d_nhwc_f32( xnn_operator_t deconvolution_op, size_t batch_size, @@ -2052,6 +2168,15 @@ enum xnn_status xnn_setup_deconvolution2d_nhwc_qd8_f32_qc8w( return setup_deconvolution2d_nhwc(deconvolution_op, 
xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w, input, quantization_params, output); } +enum xnn_status xnn_setup_deconvolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t deconvolution_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_deconvolution2d_nhwc(deconvolution_op, xnn_operator_type_deconvolution_nhwc_qdu8_f32_qc8w, input, quantization_params, output); +} + enum xnn_status xnn_setup_deconvolution2d_nhwc_f32( xnn_operator_t deconvolution_op, const float* input, diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index c74784e279f5..03e2a75f7be2 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -418,7 +418,7 @@ enum xnn_status xnn_create_fully_connected_nc_f16( fully_connected_op_out); } -enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( +enum xnn_status create_fully_connected_nc_qx8_f16_qc4w( size_t input_channels, size_t output_channels, size_t input_stride, @@ -432,19 +432,21 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config *gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } @@ -455,7 +457,7 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( if (rounded_output_min >= rounded_output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w), rounded_output_min, rounded_output_max); + xnn_operator_type_to_string(expected_operator_type), rounded_output_min, rounded_output_max); return xnn_status_invalid_parameter; } @@ -463,14 +465,13 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( xnn_log_error( "failed to create %s operator with %" PRIu8 " kernel zero point: kernel zero point must be equals to 8 " "(unsigned weights) or 0 (signed weights)", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w), kernel_zero_point); + xnn_operator_type_to_string(expected_operator_type), kernel_zero_point); return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f16_qc4w_gemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_unsupported_hardware; } @@ -516,6 +517,52 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( fully_connected_op_out); } +enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + 
uint8_t kernel_zero_point, + const float* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f16_qc4w_gemm_config(); + return create_fully_connected_nc_qx8_f16_qc4w(input_channels, output_channels, input_stride, output_stride, + kernel_zero_point, kernel_scale, kernel, bias, output_min, + output_max, flags, code_cache, weights_cache, + gemm_config, xnn_operator_type_fully_connected_nc_qd8_f16_qc4w, fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f16_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t kernel_zero_point, + const float* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f16_qc4w_gemm_config(); + return create_fully_connected_nc_qx8_f16_qc4w(input_channels, output_channels, input_stride, output_stride, + kernel_zero_point, kernel_scale, kernel, bias, output_min, + output_max, flags, code_cache, weights_cache, + gemm_config, xnn_operator_type_fully_connected_nc_qdu8_f16_qc4w, fully_connected_op_out); +} + enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( size_t input_channels, size_t output_channels, @@ -643,7 +690,7 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( fully_connected_op_out); } -enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w( +enum xnn_status create_fully_connected_nc_qx8_f32_qc4w( size_t input_channels, size_t output_channels, size_t input_stride, @@ -657,26 +704,28 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config *gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (output_min > output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc4w), output_min, output_max); + xnn_operator_type_to_string(expected_operator_type), output_min, output_max); return xnn_status_invalid_parameter; } @@ -684,11 +733,10 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w( xnn_log_error( "failed to create %s operator with %" PRIu8 " kernel zero point: kernel zero point must be equal to 8 " "(unsigned weights) or 0 (signed weights)", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc4w), kernel_zero_point); + 
xnn_operator_type_to_string(expected_operator_type), kernel_zero_point); return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc4w_gemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc4w)); @@ -732,23 +780,72 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w( /*kernel_scale_params=*/kernel_scale, ¶ms, sizeof(params), gemm_config, gemm_ukernels, - xnn_operator_type_fully_connected_nc_qd8_f32_qc4w, + expected_operator_type, /*weights_cache=*/weights_cache, fully_connected_op_out); } -enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( +enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t kernel_zero_point, + const float* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc4w_gemm_config(); + return create_fully_connected_nc_qx8_f32_qc4w(input_channels, output_channels, + input_stride, output_stride, kernel_zero_point, + kernel_scale, kernel, bias, output_min, output_max, flags, + code_cache, weights_cache, gemm_config, xnn_operator_type_fully_connected_nc_qd8_f32_qc4w, + fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t kernel_zero_point, + const float* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc4w_gemm_config(); + return create_fully_connected_nc_qx8_f32_qc4w(input_channels, output_channels, + input_stride, output_stride, kernel_zero_point, + kernel_scale, kernel, bias, output_min, output_max, flags, + code_cache, weights_cache, gemm_config, xnn_operator_type_fully_connected_nc_qdu8_f32_qc4w, + fully_connected_op_out); +} + +static enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qcxw( size_t input_channels, size_t output_channels, size_t input_stride, - size_t output_stride, uint8_t kernel_zero_point, const float* kernel_scale, - const void* kernel, const float* bias, float output_min, float output_max, - uint32_t flags, xnn_code_cache_t code_cache, - xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out) { + size_t output_stride, const float* kernel_scale, const void* kernel, + const float* bias, float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + enum xnn_operator_type operator_type, + const struct xnn_gemm_config* gemm_config, bool filter_is_nibble, + const void* packing_params, xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound " "must be non-NaN", - xnn_operator_type_to_string( - xnn_operator_type_fully_connected_nc_qp8_f32_qc4w)); + xnn_operator_type_to_string(operator_type)); return xnn_status_invalid_parameter; } @@ -756,8 +853,7 @@ enum 
xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound " "must be non-NaN", - xnn_operator_type_to_string( - xnn_operator_type_fully_connected_nc_qp8_f32_qc4w)); + xnn_operator_type_to_string(operator_type)); return xnn_status_invalid_parameter; } @@ -765,12 +861,52 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower " "bound must be less than or equal to upper bound", - xnn_operator_type_to_string( - xnn_operator_type_fully_connected_nc_qp8_f32_qc4w), - output_min, output_max); + xnn_operator_type_to_string(operator_type), output_min, output_max); return xnn_status_invalid_parameter; } + const struct gemm_fused_ukernels* gemm_ukernels = &gemm_config->minmax; + const bool linear_activation = + (output_max == INFINITY) && (output_min == -output_max); + if (linear_activation && gemm_config->linear.gemm[gemm_config->mr - 1] + .function[XNN_UARCH_DEFAULT] != NULL) { + gemm_ukernels = &gemm_config->linear; + } + + union xnn_f32_minmax_params params; + if XNN_LIKELY (gemm_config->init.f32 != NULL) { + gemm_config->init.f32(¶ms, output_min, output_max); + } + + return create_fully_connected_nc( + input_channels, output_channels, input_stride, output_stride, kernel, + /*bias=*/NULL, flags, + /*block_size=*/0, + /*extra_bl_bytes=*/0, + /*blockwise_kernel_scale_params=*/NULL, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/filter_is_nibble, + /*bias_element_size=*/sizeof(float), + (xnn_packw_gemm_gio_ukernel_fn)gemm_config->pack_gemm_gio, + (xnn_packw_gemm_goi_ukernel_fn)gemm_config->pack_gemm_goi, + /*pack_gemm_goi_bl_w=*/NULL, packing_params, + /*packed_weights_padding_byte=*/0, + /*extra_weights_bytes=*/0, + /*init_scale_params=*/NULL, + /*scale_params=*/bias, + /*init_kernel_scale_params=*/NULL, + /*kernel_scale_params=*/kernel_scale, ¶ms, sizeof(params), + gemm_config, gemm_ukernels, operator_type, + /*weights_cache=*/weights_cache, fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, uint8_t kernel_zero_point, const float* kernel_scale, + const void* kernel, const float* bias, float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out) { if (kernel_zero_point != 8 && kernel_zero_point != 0) { xnn_log_error("failed to create %s operator with %" PRIu8 " kernel zero point: kernel zero point must be equals to 8 " @@ -791,6 +927,36 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( return xnn_status_unsupported_hardware; } + // We don't know input zero point until runtime, row sum is multiplied by it + // during packing, so set it to 1. 
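[Reviewer note] The comment above relies on the dynamic-quantization identity sum(w*(x - zp)) = sum(w*x) - zp*sum(w): the packer stores the weight row sums with an assumed input zero point of 1, and the real zero point is folded in at run time. A minimal standalone sketch of that arithmetic (the toy dot product and names below are illustrative, not XNNPACK internals):

#include <stdint.h>
#include <stdio.h>

// Dot product between a signed 8-bit weight row and a dynamically quantized
// activation vector. The row sum is precomputed at pack time with zero point 1;
// the run-time zero point `zp` is applied afterwards:
//   sum(w * (x - zp)) = sum(w * x) - zp * sum(w).
static int32_t dot_qd8(const int8_t* w, const int8_t* x, size_t k, int32_t zp,
                       int32_t packed_row_sum /* = sum(w) * 1 */) {
  int32_t acc = 0;
  for (size_t i = 0; i < k; i++) {
    acc += (int32_t) w[i] * (int32_t) x[i];
  }
  return acc - zp * packed_row_sum;
}

int main(void) {
  const int8_t w[4] = {1, -2, 3, 4};
  const int8_t x[4] = {10, 20, 30, 40};
  const int32_t row_sum = 1 - 2 + 3 + 4;  // what the packer would store
  printf("%d\n", (int) dot_qd8(w, x, 4, /*zp=*/5, row_sum));
  return 0;
}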
+ const struct xnn_qs8_qc4w_packing_params packing_params = { + /*input_zero_point=*/1, kernel_zero_point}; + + return xnn_create_fully_connected_nc_qp8_f32_qcxw( + input_channels, output_channels, input_stride, output_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, + weights_cache, + /*operator_type=*/xnn_operator_type_fully_connected_nc_qp8_f32_qc4w, + gemm_config, /*filter_is_nibble=*/true, &packing_params, + fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc8w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, const float* kernel_scale, const void* kernel, + const float* bias, float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) { + const struct xnn_gemm_config* gemm_config = + xnn_init_qp8_f32_qc8w_gemm_config(); + if (gemm_config == NULL) { + xnn_log_error( + "failed to create %s operator: unsupported hardware configuration", + xnn_operator_type_to_string( + xnn_operator_type_fully_connected_nc_qp8_f32_qc8w)); + return xnn_status_unsupported_hardware; + } + const struct gemm_fused_ukernels* gemm_ukernels = &gemm_config->minmax; const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max); @@ -799,39 +965,18 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( gemm_ukernels = &gemm_config->linear; } - union xnn_f32_minmax_params params; - if XNN_LIKELY (gemm_config->init.f32 != NULL) { - gemm_config->init.f32(¶ms, output_min, output_max); - } - // We don't know input zero point until runtime, row sum is multiplied by it // during packing, so set it to 1. - const struct xnn_qs8_qc4w_packing_params packing_params = { - /*input_zero_point=*/1, kernel_zero_point}; + const struct xnn_qs8_qc8w_packing_params packing_params = { + /*input_zero_point=*/1, 1.0f}; - return create_fully_connected_nc( - input_channels, output_channels, input_stride, output_stride, kernel, - /*bias=*/NULL, flags, - /*block_size=*/0, - /*extra_bl_bytes=*/0, - /*blockwise_kernel_scale_params=*/NULL, - /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, - /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, - /*filter_is_nibble=*/true, - /*bias_element_size=*/sizeof(float), - (xnn_packw_gemm_gio_ukernel_fn)gemm_config->pack_gemm_gio, - (xnn_packw_gemm_goi_ukernel_fn)gemm_config->pack_gemm_goi, - /*pack_gemm_goi_bl_w=*/NULL, - &packing_params, - /*packed_weights_padding_byte=*/0, - /*extra_weights_bytes=*/0, - /*init_scale_params=*/NULL, - /*scale_params=*/bias, - /*init_kernel_scale_params=*/NULL, - /*kernel_scale_params=*/kernel_scale, ¶ms, sizeof(params), - gemm_config, gemm_ukernels, - xnn_operator_type_fully_connected_nc_qp8_f32_qc4w, - /*weights_cache=*/weights_cache, fully_connected_op_out); + return xnn_create_fully_connected_nc_qp8_f32_qcxw( + input_channels, output_channels, input_stride, output_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, + weights_cache, + /*operator_type=*/xnn_operator_type_fully_connected_nc_qp8_f32_qc8w, + gemm_config, /*filter_is_nibble=*/false, &packing_params, + fully_connected_op_out); } enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( @@ -958,7 +1103,7 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qb4w( fully_connected_op_out); } -enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( +enum xnn_status create_fully_connected_nc_qx8_f32_qb4w( size_t input_channels, 
size_t output_channels, size_t input_stride, @@ -973,30 +1118,31 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config *gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (output_min > output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w), output_min, output_max); + xnn_operator_type_to_string(expected_operator_type), output_min, output_max); return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qb4w_gemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w)); @@ -1027,7 +1173,7 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( xnn_log_error( "failed to create %s operator with %" PRIu8 " kernel zero point: kernel zero point must be equal to 8 " "(unsigned weights) or 0 (signed weights)", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w), kernel_zero_point); + xnn_operator_type_to_string(expected_operator_type), kernel_zero_point); return xnn_status_invalid_parameter; } // Assuming kernel_scale.size() is output_channels * num_blocks. 
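[Reviewer note] For reference, the scale layout implied by the comment above ("kernel_scale.size() is output_channels * num_blocks") can be indexed as sketched below. The output-channel-major ordering and the use of plain float are assumptions made for illustration only; the real qb4w API passes bf16 scales as uint16_t.

#include <stddef.h>

// One scale per (output channel, block of block_size input channels),
// stored output-channel-major. Hypothetical helper, not part of the API.
static float block_scale(const float* kernel_scale, size_t input_channels,
                         size_t block_size, size_t oc, size_t block) {
  const size_t num_blocks = (input_channels + block_size - 1) / block_size;
  return kernel_scale[oc * num_blocks + block];
}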
@@ -1080,7 +1226,55 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( fully_connected_op_out); } -enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w( +enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t block_size, + uint8_t kernel_zero_point, + const uint16_t* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qb4w_gemm_config(); + return create_fully_connected_nc_qx8_f32_qb4w(input_channels, output_channels, input_stride, output_stride, + block_size, kernel_zero_point, kernel_scale, kernel, bias, + output_min, output_max, flags, code_cache, weights_cache, + gemm_config, xnn_operator_type_fully_connected_nc_qd8_f32_qb4w, fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qb4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t block_size, + uint8_t kernel_zero_point, + const uint16_t* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qb4w_gemm_config(); + return create_fully_connected_nc_qx8_f32_qb4w(input_channels, output_channels, input_stride, output_stride, + block_size, kernel_zero_point, kernel_scale, kernel, bias, + output_min, output_max, flags, code_cache, weights_cache, + gemm_config, xnn_operator_type_fully_connected_nc_qdu8_f32_qb4w, fully_connected_op_out); +} + +enum xnn_status create_fully_connected_nc_qdx8_f32_qc8w( size_t input_channels, size_t output_channels, size_t input_stride, @@ -1093,33 +1287,34 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w( uint32_t flags, xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config* gemm_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (output_min > output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be less than or equal to upper bound", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc8w), output_min, output_max); + xnn_operator_type_to_string(expected_operator_type), output_min, output_max); return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc8w_gemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", - 
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f32_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_unsupported_hardware; } @@ -1157,12 +1352,34 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w( xnn_init_qs8_qc8w_scale_fp32_params, kernel_scale, ¶ms, sizeof(params), gemm_config, gemm_ukernels, - xnn_operator_type_fully_connected_nc_qd8_f32_qc8w, + expected_operator_type, /*weights_cache=*/weights_cache, fully_connected_op_out); } -enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w( +enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f32_qc8w_gemm_config(); + return create_fully_connected_nc_qdx8_f32_qc8w(input_channels, output_channels, input_stride, output_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, + weights_cache, gemm_config, xnn_operator_type_fully_connected_nc_qd8_f32_qc8w, + fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qc8w( size_t input_channels, size_t output_channels, size_t input_stride, @@ -1176,18 +1393,42 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w( xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_gemm_config(); + return create_fully_connected_nc_qdx8_f32_qc8w(input_channels, output_channels, input_stride, output_stride, + kernel_scale, kernel, bias, output_min, output_max, flags, code_cache, + weights_cache, gemm_config, xnn_operator_type_fully_connected_nc_qdu8_f32_qc8w, + fully_connected_op_out); +} + +enum xnn_status create_fully_connected_nc_qx8_f16_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + const struct xnn_gemm_config *gemm_config, + enum xnn_operator_type expected_operator_type, + xnn_operator_t* fully_connected_op_out) { if (isnan(output_min)) { xnn_log_error( "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } if (isnan(output_max)) { xnn_log_error( "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_invalid_parameter; } @@ -1198,14 +1439,13 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w( if (rounded_output_min >= rounded_output_max) { xnn_log_error( "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc8w), rounded_output_min, rounded_output_max); + 
xnn_operator_type_to_string(expected_operator_type), rounded_output_min, rounded_output_max); return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f16_qc8w_gemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", - xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_qd8_f16_qc8w)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_unsupported_hardware; } @@ -1248,6 +1488,50 @@ enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w( fully_connected_op_out); } +enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qd8_f16_qc8w_gemm_config(); + return create_fully_connected_nc_qx8_f16_qc8w(input_channels, output_channels, input_stride, output_stride, kernel_scale, kernel, bias, + output_min, output_max, flags, code_cache, weights_cache, gemm_config, + xnn_operator_type_fully_connected_nc_qd8_f16_qc8w, + fully_connected_op_out); +} + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f16_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out) +{ + const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f16_qc8w_gemm_config(); + return create_fully_connected_nc_qx8_f16_qc8w(input_channels, output_channels, input_stride, output_stride, kernel_scale, kernel, bias, + output_min, output_max, flags, code_cache, weights_cache, gemm_config, + xnn_operator_type_fully_connected_nc_qdu8_f16_qc8w, + fully_connected_op_out); +} + enum xnn_status xnn_create_fully_connected_nc_f32_f16( size_t input_channels, size_t output_channels, @@ -1955,10 +2239,13 @@ static enum xnn_status reshape_fully_connected_nc( input_channels = round_up_po2(input_channels, planes); } - const bool is_qp8_ukernel = fully_connected_op->type == - xnn_operator_type_fully_connected_nc_qp8_f32_qc4w || - fully_connected_op->type == - xnn_operator_type_fully_connected_nc_qp8_f32_qb4w; + const bool is_qp8_ukernel = + (fully_connected_op->type == + xnn_operator_type_fully_connected_nc_qp8_f32_qc4w) || + (fully_connected_op->type == + xnn_operator_type_fully_connected_nc_qp8_f32_qc8w) || + (fully_connected_op->type == + xnn_operator_type_fully_connected_nc_qp8_f32_qb4w); fully_connected_op->context.gemm.gemm.gemm = (struct gemm_context){ .k_scaled = input_channels << log2_input_element_size, @@ -2128,6 +2415,25 @@ enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc4w( threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f16_qc4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) +{ + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f16_qc4w, + batch_size, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + // Pass 1 byte even though it is half byte, we handle the division via filter_is_nibble == 
true. + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/true, + /*dynamic_quantization=*/true, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_HALF, + &fully_connected_op->params.f32_qc4w_minmax, + sizeof(fully_connected_op->params.f32_qc4w_minmax), + threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qb4w( xnn_operator_t fully_connected_op, size_t batch_size, @@ -2166,6 +2472,25 @@ enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc4w( threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f32_qc4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) +{ + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f32_qc4w, + batch_size, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + // Pass 1 byte even though it is half byte, we handle the division via filter_is_nibble == true. + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/true, + /*dynamic_quantization=*/true, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &fully_connected_op->params.f32_qc4w_minmax, + sizeof(fully_connected_op->params.f32_qc4w_minmax), + threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qb4w( xnn_operator_t fully_connected_op, size_t batch_size, @@ -2185,6 +2510,25 @@ enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qb4w( threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f32_qb4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) +{ + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f32_qb4w, + batch_size, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + // Pass 1 byte even though it is half byte, we handle the division via filter_is_nibble == true. 
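[Reviewer note] The repeated comment above ("Pass 1 byte even though it is half byte") refers to 4-bit weights being nibble-packed two per byte: the reshape reports a 1-byte filter element size and filter_is_nibble halves the effective size. A minimal unpacking sketch; the nibble order and helper name are assumptions for illustration, not taken from XNNPACK's packing code:

#include <stdint.h>

// Two 4-bit weights share one byte. Subtracting the kernel zero point
// (8 for unsigned nibbles, 0 for signed) recovers the signed weight values.
static void unpack_qc4w_pair(uint8_t byte, uint8_t zero_point,
                             int32_t* w_lo, int32_t* w_hi) {
  *w_lo = (int32_t)(byte & 0x0F) - (int32_t) zero_point;
  *w_hi = (int32_t)(byte >> 4) - (int32_t) zero_point;
}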
+ /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/true, + /*dynamic_quantization=*/true, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &fully_connected_op->params.f32_qb4w_minmax, + sizeof(fully_connected_op->params.f32_qb4w_minmax), + threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc8w( xnn_operator_t fully_connected_op, size_t batch_size, @@ -2203,6 +2547,24 @@ enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc8w( threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f16_qc8w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) +{ + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f16_qc8w, + batch_size, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*filter_is_nibble=*/false, + /*dynamic_quantization=*/true, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_HALF, + &fully_connected_op->params.f16_minmax, + sizeof(fully_connected_op->params.f16_minmax), + threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc8w( xnn_operator_t fully_connected_op, size_t batch_size, @@ -2221,6 +2583,24 @@ enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc8w( threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f32_qc8w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool) +{ + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f32_qc8w, + batch_size, + /*log2_input_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_INT8_T, + /*filter_is_nibble=*/false, + /*dynamic_quantization=*/true, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &fully_connected_op->params.f32_minmax, + sizeof(fully_connected_op->params.f32_minmax), + threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qc4w( xnn_operator_t fully_connected_op, size_t batch_size, pthreadpool_t threadpool) { @@ -2238,6 +2618,21 @@ enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qc4w( sizeof(fully_connected_op->params.f32_minmax), threadpool); } +enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qc8w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool) { + return reshape_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qc8w, + batch_size, + /*log2_input_element_size=*/0, + /*log2_filter_element_size=*/XNN_LOG2_SIZEOF_UINT8_T, + /*filter_is_nibble=*/false, + /*dynamic_quantization=*/false, + /*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT, + &fully_connected_op->params.f32_minmax, + sizeof(fully_connected_op->params.f32_minmax), threadpool); +} + enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qb4w( xnn_operator_t fully_connected_op, size_t batch_size, pthreadpool_t threadpool) { @@ -2413,6 +2808,17 @@ enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc4w( input, output, quantization_params); } +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f16_qc4w( + xnn_operator_t fully_connected_op, + const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f16_qc4w, + input, output, quantization_params); +} + enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qb4w( xnn_operator_t fully_connected_op, const int8_t* 
input, @@ -2435,6 +2841,17 @@ enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc4w( input, output, quantization_params); } +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f32_qc4w( + xnn_operator_t fully_connected_op, + const uint8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f32_qc4w, + input, output, quantization_params); +} + enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qb4w( xnn_operator_t fully_connected_op, const int8_t* input, @@ -2446,6 +2863,17 @@ enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qb4w( input, output, quantization_params); } +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f32_qb4w( + xnn_operator_t fully_connected_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f32_qb4w, + input, output, quantization_params); +} + enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc8w( xnn_operator_t fully_connected_op, const int8_t* input, @@ -2457,6 +2885,17 @@ enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc8w( input, output, quantization_params); } +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f16_qc8w( + xnn_operator_t fully_connected_op, + const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f16_qc8w, + input, output, quantization_params); +} + enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc4w( xnn_operator_t fully_connected_op, const int8_t* input, float* output) { return setup_fully_connected_nc( @@ -2464,6 +2903,13 @@ enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc4w( input, output, /*quantization_params=*/NULL); } +enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc8w( + xnn_operator_t fully_connected_op, const int8_t* input, float* output) { + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qp8_f32_qc8w, + input, output, /*quantization_params=*/NULL); +} + enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qb4w( xnn_operator_t fully_connected_op, const int8_t* input, float* output) { return setup_fully_connected_nc( @@ -2482,6 +2928,17 @@ enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc8w( input, output, quantization_params); } +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f32_qc8w( + xnn_operator_t fully_connected_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params) +{ + return setup_fully_connected_nc( + fully_connected_op, xnn_operator_type_fully_connected_nc_qdu8_f32_qc8w, + input, output, quantization_params); +} + enum xnn_status xnn_setup_fully_connected_nc_qs8( xnn_operator_t fully_connected_op, const int8_t* input, diff --git a/src/operators/max-pooling-nhwc.c b/src/operators/max-pooling-nhwc.c index e8dd7413d371..b9c536d14b36 100644 --- a/src/operators/max-pooling-nhwc.c +++ b/src/operators/max-pooling-nhwc.c @@ -501,7 +501,16 @@ static enum xnn_status reshape_max_pooling2d_nhwc( // Set a dummy input first, the actual input offset is calculated in setup when we have the input pointer. 
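[Reviewer note] As context for the xnn_indirection_init_maxpool2d change below: an indirection buffer stores, per output element, one input pointer per kernel tap, so the microkernel dereferences a table instead of recomputing offsets. A heavily simplified 1-D sketch of the idea, not the actual XNNPACK routine (which handles 2-D geometry, strides, dilation, and padding explicitly):

#include <stddef.h>

// Each output element gets kernel_size pointers into the input. Real code
// points out-of-bounds taps at a padding buffer; this sketch simply clamps.
static void init_indirection_1d(const float** indirection, const float* input,
                                size_t input_size, size_t output_size,
                                size_t kernel_size, size_t stride) {
  for (size_t o = 0; o < output_size; o++) {
    for (size_t k = 0; k < kernel_size; k++) {
      size_t pos = o * stride + k;
      if (pos >= input_size) pos = input_size - 1;  // clamp at the edge
      indirection[o * kernel_size + k] = input + pos;
    }
  }
}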
max_pooling_op->input = NULL; - xnn_indirection_init_maxpool2d(max_pooling_op, step_height, step_width, log2_input_element_size); + xnn_indirection_init_maxpool2d( + max_pooling_op->indirection_buffer, max_pooling_op->input, + max_pooling_op->input_pixel_stride << log2_input_element_size, + max_pooling_op->input_height, max_pooling_op->input_width, + max_pooling_op->output_height, max_pooling_op->output_width, + max_pooling_op->kernel_height, max_pooling_op->kernel_width, + max_pooling_op->stride_height, max_pooling_op->stride_width, + max_pooling_op->dilation_height, max_pooling_op->dilation_width, + max_pooling_op->padding_top, max_pooling_op->padding_left, + step_height, step_width); max_pooling_op->last_input = max_pooling_op->input; max_pooling_op->last_input_height = input_height; diff --git a/src/operators/reduce-nd.c b/src/operators/reduce-nd.c index ffff70a29e7c..865bfe831c6f 100644 --- a/src/operators/reduce-nd.c +++ b/src/operators/reduce-nd.c @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -16,6 +17,7 @@ #include "xnnpack/compute.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" +#include "xnnpack/reference-config.h" #include "xnnpack/datatype.h" #include "xnnpack/log.h" #include "xnnpack/microkernel-type.h" @@ -34,10 +36,10 @@ static enum xnn_status create_reduce_nd( const struct xnn_reduce_config* rdsum_config, const struct xnn_reduce_config* rsum_config, const struct xnn_unary_elementwise_config* cvt_config, - const struct xnn_unary_elementwise_config* s32_f32_cvt_config, - const struct xnn_unary_elementwise_config* u32_f32_cvt_config, const void* params, size_t params_size, + const void* cvt_params, + size_t cvt_params_size, xnn_operator_t* reduce_op_out) { xnn_operator_t reduce_op = NULL; @@ -64,13 +66,14 @@ static enum xnn_status create_reduce_nd( reduce_op->rdsum_config = rdsum_config; reduce_op->rsum_config = rsum_config; reduce_op->cvt_config = cvt_config; - reduce_op->s32_f32_cvt_config = s32_f32_cvt_config; - reduce_op->u32_f32_cvt_config = u32_f32_cvt_config; reduce_op->reduce.log2_data_element_size = log2_data_element_size; reduce_op->reduce.log2_accumulator_element_size = log2_accumulator_element_size; if (params_size != 0) { memcpy(&reduce_op->params, params, params_size); } + if (cvt_params_size != 0) { + memcpy(&reduce_op->params2, cvt_params, cvt_params_size); + } reduce_op->state = xnn_run_state_invalid; @@ -195,20 +198,22 @@ static enum xnn_status reshape_reduce_nd( if (workspace_alignment != NULL) { *workspace_alignment = XNN_ALLOCATION_ALIGNMENT; } + + size_t num_reduction_elements; if (normalized_reduction_axes[num_reduction_axes - 1] == num_input_dims - 1) { if (workspace_size != NULL) { const size_t num_output_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4]; *workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES; } - const size_t scale_dim = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5]; + num_reduction_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5]; const size_t axis_dim = normalized_input_shape[5]; if (reduce_op->rsum_config->update != NULL) { float scale = 1.0f; if (reduce_op->type == xnn_operator_type_mean_nd) { - scale = 1.0f / scale_dim; + scale = 1.0f / num_reduction_elements; } - reduce_op->rsum_config->update(&reduce_op->params.reduce, scale, scale_dim); + reduce_op->rsum_config->update(&reduce_op->params.reduce, scale); } 
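A quick sketch of the idea behind the `update` hook above (illustrative C only, not the XNNPACK API): a mean over N reduction elements is computed as a plain sum whose result is multiplied by a pre-folded scale of 1/N.

#include <stddef.h>

static float reduce_mean_as_scaled_sum(const float* x, size_t n) {
  float acc = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    acc += x[i];  // the rsum/rdsum microkernels accumulate the raw sum
  }
  const float scale = 1.0f / (float) n;  // folded into the params by `update` for mean
  return acc * scale;
}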
reduce_op->context.reduce = (struct reduce_context) { @@ -217,7 +222,6 @@ static enum xnn_status reshape_reduce_nd( .accumulation_element_size = UINT32_C(1) << log2_accumulator_element_size, .output_element_size = UINT32_C(1) << log2_data_element_size, }; - memcpy(&reduce_op->context.reduce.params, &reduce_op->params.reduce, sizeof(reduce_op->params.reduce)); reduce_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_contiguous_reduce; reduce_op->compute[0].range[0] = normalized_input_shape[0]; @@ -229,13 +233,6 @@ static enum xnn_status reshape_reduce_nd( for (int i = XNN_MAX_TENSOR_DIMS / 2 - 2; i >= 0; --i) { reduce_op->context.reduce.output_stride[i] = (reduce_op->context.reduce.output_stride[i + 1] * normalized_input_shape[(i + 1) * 2]); } - - if (reduce_op->s32_f32_cvt_config) { - reduce_op->context.reduce.s32_f32_cvt_ukernel = reduce_op->s32_f32_cvt_config->ukernel; - } - if (reduce_op->u32_f32_cvt_config) { - reduce_op->context.reduce.u32_f32_cvt_ukernel = reduce_op->u32_f32_cvt_config->ukernel; - } } else { // Reduction along the non-innermost dimension const size_t channel_like_dim = normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1]; @@ -243,15 +240,15 @@ static enum xnn_status reshape_reduce_nd( const size_t num_output_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5]; *workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES; } - const size_t scale_dim = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4]; + num_reduction_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4]; const size_t axis_dim = normalized_input_shape[4]; if (reduce_op->rdsum_config->update != NULL) { float scale = 1.0f; if (reduce_op->type == xnn_operator_type_mean_nd) { - scale = 1.0f / scale_dim; + scale = 1.0f / num_reduction_elements; } - reduce_op->rdsum_config->update(&reduce_op->params.reduce, scale, scale_dim); + reduce_op->rdsum_config->update(&reduce_op->params.reduce, scale); } if (reduce_op->channels != channel_like_dim) { const size_t zero_size = (channel_like_dim << log2_data_element_size) + XNN_EXTRA_BYTES; @@ -274,7 +271,6 @@ static enum xnn_status reshape_reduce_nd( .accumulation_element_size = UINT32_C(1) << log2_accumulator_element_size, .output_element_size = UINT32_C(1) << log2_data_element_size, }; - memcpy(&reduce_op->context.reduce.params, &reduce_op->params.reduce, sizeof(reduce_op->params.reduce)); reduce_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_discontiguous_reduce; reduce_op->compute[0].range[0] = normalized_input_shape[1]; reduce_op->compute[0].range[1] = normalized_input_shape[3]; @@ -286,18 +282,35 @@ static enum xnn_status reshape_reduce_nd( reduce_op->context.reduce.output_stride[i] = (reduce_op->context.reduce.output_stride[i + 1] * normalized_input_shape[(i * 2+3)]); } } + memcpy(&reduce_op->context.reduce.params, &reduce_op->params.reduce, sizeof(reduce_op->params.reduce)); + memcpy(&reduce_op->context.reduce.cvt_params, &reduce_op->params2.unary, sizeof(reduce_op->params2.unary)); reduce_op->context.reduce.input_stride[XNN_MAX_TENSOR_DIMS - 1] = (1 << log2_data_element_size); if (reduce_op->cvt_config) { reduce_op->context.reduce.cvt_ukernel = reduce_op->cvt_config->ukernel; + // int32 is not actually a quantized type, so we need to include the input + // zero point (multiplied by the number of reduction elements) as part of + // the computation of the output zero 
point. + // The conversion normally looks like: + // + // y = (x - x_zero_point) * x_scale * inv_y_scale + y_zero_point + // + // Since this conversion ignores x_zero_point and x_scale, rewrite to: + // + // y = x * x_scale * inv_y_scale - x_zero_point * x_scale * inv_y_scale + y_zero_point + // + // Now we can say: + // + // inv_y_scale' = x_scale * inv_y_scale + // y_zero_point' = y_zero_point - x_zero_point * x_scale * inv_y_scale + reduce_op->context.reduce.cvt_params.reference.inv_y_scale = + reduce_op->context.reduce.params.qs8.scale; + reduce_op->context.reduce.cvt_params.reference.y_zero_point -= + ((int32_t) num_reduction_elements * + reduce_op->context.reduce.cvt_params.reference.x_zero_point) * + reduce_op->context.reduce.cvt_params.reference.inv_y_scale; } - if (reduce_op->s32_f32_cvt_config) { - reduce_op->context.reduce.s32_f32_cvt_ukernel = reduce_op->s32_f32_cvt_config->ukernel; - } - if (reduce_op->u32_f32_cvt_config) { - reduce_op->context.reduce.u32_f32_cvt_ukernel = reduce_op->u32_f32_cvt_config->ukernel; - } - for (int i = XNN_MAX_TENSOR_DIMS - 2; i >= 0; --i) { - reduce_op->context.reduce.input_stride[i] = (reduce_op->context.reduce.input_stride[i + 1] * normalized_input_shape[i + 1]); + for (int i = XNN_MAX_TENSOR_DIMS - 2; i >= 0; --i) { + reduce_op->context.reduce.input_stride[i] = (reduce_op->context.reduce.input_stride[i + 1] * normalized_input_shape[i + 1]); } memcpy(reduce_op->context.reduce.input_shape, normalized_input_shape, XNN_MAX_TENSOR_DIMS * sizeof(size_t)); reduce_op->state = xnn_run_state_needs_setup; @@ -363,8 +376,6 @@ enum xnn_status xnn_create_reduce_nd( const struct xnn_reduce_config* rsum_config = NULL; const struct xnn_reduce_config* rdsum_config = NULL; const struct xnn_unary_elementwise_config* cvt_config = NULL; - const struct xnn_unary_elementwise_config* s32_f32_cvt_config = NULL; - const struct xnn_unary_elementwise_config* u32_f32_cvt_config = NULL; uint32_t log2_data_element_size = xnn_datatype_log2_size_bytes(datatype); uint32_t log2_accumulator_element_size; switch(datatype) { @@ -373,8 +384,6 @@ enum xnn_status xnn_create_reduce_nd( rsum_config = xnn_init_f16_f32acc_rsum_config(); rdsum_config = xnn_init_f16_f32acc_rdsum_config(); cvt_config = xnn_init_f32_to_f16_cvt_config(); - s32_f32_cvt_config = unused; - u32_f32_cvt_config = unused; break; } case xnn_datatype_fp32: { @@ -382,28 +391,22 @@ enum xnn_status xnn_create_reduce_nd( rsum_config = xnn_init_f32_rsum_config(); rdsum_config = xnn_init_f32_rdsum_config(); cvt_config = unused; - s32_f32_cvt_config = unused; - u32_f32_cvt_config = unused; break; } case xnn_datatype_qint8: { // qs8 log2_accumulator_element_size = 2; rsum_config = xnn_init_qs8_rsum_config(); rdsum_config = xnn_init_qs8_rdsum_config(); - cvt_config = xnn_init_f32_to_qs8_cvt_config(); - s32_f32_cvt_config = xnn_init_s32_to_f32_cvt_config(); - u32_f32_cvt_config = unused; + cvt_config = xnn_init_unary_reference_config(xnn_unary_convert, xnn_datatype_int32, xnn_datatype_qint8); break; } case xnn_datatype_quint8: { // qu8 log2_accumulator_element_size = 2; rsum_config = xnn_init_qu8_rsum_config(); rdsum_config = xnn_init_qu8_rdsum_config(); - cvt_config = xnn_init_f32_to_qu8_cvt_config(); - s32_f32_cvt_config = unused; - // We just use an int32 -> f32 conversion. This means we effectively only + // We just use an int32 -> qu8 conversion. This means we effectively only // have a 31-bit accumulator instead of 32-bit, but that seems insignificant. 
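To make the zero-point folding derived in the reshape_reduce_nd comment above concrete, here is a small standalone sketch (a hypothetical helper, not part of this diff) that quantizes an int32 accumulator holding the sum of n quantized inputs using the rewritten parameters inv_y_scale' and y_zero_point'. Clamping to the qs8/qu8 range is omitted for brevity.

#include <math.h>
#include <stdint.h>
#include <stddef.h>

static int8_t quantize_reduced_sum(int32_t acc, size_t n, int32_t x_zero_point,
                                   float x_scale, float inv_y_scale,
                                   int32_t y_zero_point) {
  // inv_y_scale' = x_scale * inv_y_scale
  const float folded_inv_y_scale = x_scale * inv_y_scale;
  // y_zero_point' = y_zero_point - n * x_zero_point * inv_y_scale'
  // (n appears because each of the n summed inputs carries x_zero_point)
  const float folded_y_zero_point =
      (float) y_zero_point -
      (float) ((int32_t) n * x_zero_point) * folded_inv_y_scale;
  return (int8_t) lrintf((float) acc * folded_inv_y_scale + folded_y_zero_point);
}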
- u32_f32_cvt_config = xnn_init_s32_to_f32_cvt_config(); + cvt_config = xnn_init_unary_reference_config(xnn_unary_convert, xnn_datatype_int32, xnn_datatype_quint8); break; } default: @@ -413,16 +416,13 @@ enum xnn_status xnn_create_reduce_nd( }; // Check configs and restore unused pointers to NULL. - if (rdsum_config == NULL || rsum_config == NULL || cvt_config == NULL || - s32_f32_cvt_config == NULL || u32_f32_cvt_config == NULL) { + if (rdsum_config == NULL || rsum_config == NULL || cvt_config == NULL) { xnn_log_error( "failed to create %s (%s) operator: unsupported hardware configuration", xnn_operator_type_to_string(operator_type), xnn_datatype_to_string(datatype)); return xnn_status_unsupported_hardware; } else { cvt_config = cvt_config == unused ? NULL : cvt_config; - s32_f32_cvt_config = s32_f32_cvt_config == unused ? NULL : s32_f32_cvt_config; - u32_f32_cvt_config = u32_f32_cvt_config == unused ? NULL : u32_f32_cvt_config; } struct xnn_reduce_params params; @@ -431,11 +431,15 @@ enum xnn_status xnn_create_reduce_nd( if (rsum_config->init) { params_size = rsum_config->init(¶ms, input_quantization, output_quantization); } + union xnn_unary_uparams cvt_params; + size_t cvt_params_size = 0; + if (cvt_config && cvt_config->init) { + cvt_params_size = cvt_config->init(&cvt_params, NULL, input_quantization, output_quantization); + } return create_reduce_nd( flags, log2_data_element_size, log2_accumulator_element_size, operator_type, - rdsum_config, rsum_config, cvt_config, s32_f32_cvt_config, - u32_f32_cvt_config, ¶ms, params_size, reduce_op_out); + rdsum_config, rsum_config, cvt_config, ¶ms, params_size, &cvt_params, cvt_params_size, reduce_op_out); } enum xnn_status xnn_reshape_reduce_nd( diff --git a/src/operators/unary-elementwise-nc.c b/src/operators/unary-elementwise-nc.c index ba2f2f7e781f..3954085192d9 100644 --- a/src/operators/unary-elementwise-nc.c +++ b/src/operators/unary-elementwise-nc.c @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. 
#include -#include #include #include #include @@ -17,14 +16,14 @@ #include "xnnpack/config-types.h" #include "xnnpack/config.h" #include "xnnpack/datatype.h" +#include "xnnpack/internal.h" #include "xnnpack/log.h" -#include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" -#include "xnnpack/node-type.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator-utils.h" #include "xnnpack/operator.h" +#include "xnnpack/packq.h" #include "xnnpack/params.h" #include "xnnpack/reference-config.h" #include "pthreadpool.h" @@ -90,6 +89,8 @@ static const struct xnn_unary_elementwise_config* get_config( return xnn_init_f32_to_qs8_cvt_config(); } else if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_quint8) { return xnn_init_f32_to_qu8_cvt_config(); + } else if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_qpint8) { + return xnn_init_f32_to_qp8_cvt_config(); } else if (input_datatype == xnn_datatype_fp16 && output_datatype == xnn_datatype_fp32) { return xnn_init_f16_to_f32_cvt_config(); } else if (input_datatype == xnn_datatype_fp16 && output_datatype == xnn_datatype_qint8) { @@ -711,56 +712,85 @@ static enum xnn_status setup_unary_elementwise_nc( return xnn_status_success; } -enum xnn_status xnn_create_convert_nc_f16_qd8( +enum xnn_status create_convert_nc_f16_qx8( uint32_t flags, + const struct xnn_unary_elementwise_config* cvt_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* convert_op_out) { const struct xnn_reduce_config* f16_rminmax_config = xnn_init_f16_rminmax_config(); if (f16_rminmax_config == NULL) { xnn_log_error( "failed to create %s operator: unsupported hardware configuration", - xnn_operator_type_to_string(xnn_operator_type_convert_nc_f16_qd8)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_unsupported_hardware; } struct xnn_f16_default_params params; enum xnn_status status = create_unary_elementwise_nc( - flags, xnn_init_f16_to_qs8_cvt_config(), + flags, cvt_config, ¶ms, sizeof(params), - xnn_operator_type_convert_nc_f16_qd8, convert_op_out); + expected_operator_type, convert_op_out); if (status == xnn_status_success) { (*convert_op_out)->rminmax_config = f16_rminmax_config; } return status; } -enum xnn_status xnn_create_convert_nc_f32_qd8( +enum xnn_status create_convert_nc_f32_qx8( uint32_t flags, + const struct xnn_unary_elementwise_config* cvt_config, + enum xnn_operator_type expected_operator_type, xnn_operator_t* convert_op_out) { const struct xnn_reduce_config* f32_rminmax_config = xnn_init_f32_rminmax_config(); if (f32_rminmax_config == NULL) { xnn_log_error( "failed to create %s operator: unsupported hardware configuration", - xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qd8)); + xnn_operator_type_to_string(expected_operator_type)); return xnn_status_unsupported_hardware; } struct xnn_f32_default_params params; enum xnn_status status = create_unary_elementwise_nc( - flags, xnn_init_f32_to_qs8_cvt_config(), + flags, cvt_config, ¶ms, sizeof(params), - xnn_operator_type_convert_nc_f32_qd8, convert_op_out); + expected_operator_type, convert_op_out); if (status == xnn_status_success) { (*convert_op_out)->rminmax_config = f32_rminmax_config; } return status; } -enum xnn_status xnn_create_convert_nc_f32_qp8(uint32_t flags, - xnn_operator_t* convert_op_out) { +enum xnn_status xnn_create_convert_nc_f16_qd8( + uint32_t flags, + xnn_operator_t* convert_op_out) { + return create_convert_nc_f16_qx8(flags, 
xnn_init_f16_to_qs8_cvt_config(), xnn_operator_type_convert_nc_f16_qd8, convert_op_out); +} + +enum xnn_status xnn_create_convert_nc_f16_qdu8( + uint32_t flags, + xnn_operator_t* convert_op_out) { + return create_convert_nc_f16_qx8(flags, xnn_init_f16_to_qu8_cvt_config(), xnn_operator_type_convert_nc_f16_qdu8, convert_op_out); +} + +enum xnn_status xnn_create_convert_nc_f32_qd8( + uint32_t flags, + xnn_operator_t* convert_op_out) { + return create_convert_nc_f32_qx8(flags, xnn_init_f32_to_qs8_cvt_config(), xnn_operator_type_convert_nc_f32_qd8, convert_op_out); +} + +enum xnn_status xnn_create_convert_nc_f32_qdu8( + uint32_t flags, + xnn_operator_t* convert_op_out) { + return create_convert_nc_f32_qx8(flags, xnn_init_f32_to_qu8_cvt_config(), xnn_operator_type_convert_nc_f32_qdu8, convert_op_out); +} + +enum xnn_status xnn_create_convert_nc_f32_qp8( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + xnn_operator_t* convert_op_out) { const struct xnn_reduce_config* f32_rminmax_config = xnn_init_f32_rminmax_config(); if (f32_rminmax_config == NULL) { @@ -778,6 +808,7 @@ enum xnn_status xnn_create_convert_nc_f32_qp8(uint32_t flags, xnn_operator_type_convert_nc_f32_qp8, convert_op_out); if (status == xnn_status_success) { (*convert_op_out)->rminmax_config = f32_rminmax_config; + (*convert_op_out)->gemm_config = gemm_config; } return status; } @@ -812,17 +843,18 @@ enum xnn_status xnn_create_copy_nc_x32( xnn_operator_type_copy_nc_x32, copy_op_out); } -enum xnn_status xnn_reshape_convert_nc_f16_qd8( +enum xnn_status reshape_convert_nc_f16_qx8( xnn_operator_t convert_op, size_t batch_size, size_t channels, size_t input_stride, size_t output_stride, + enum xnn_operator_type expected_type, pthreadpool_t threadpool) { - if (convert_op->type != xnn_operator_type_convert_nc_f16_qd8) { + if (convert_op->type != expected_type) { xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_convert_nc_f16_qd8), + xnn_operator_type_to_string(expected_type), xnn_operator_type_to_string(convert_op->type)); return xnn_status_invalid_parameter; } @@ -848,6 +880,16 @@ enum xnn_status xnn_reshape_convert_nc_f16_qd8( convert_op->compute[0].type = xnn_parallelization_type_1d; convert_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_f16_qd8_convert; + switch (expected_type) { + case xnn_operator_type_convert_nc_f16_qd8: + convert_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_f16_qd8_convert; + break; + case xnn_operator_type_convert_nc_f16_qdu8: + convert_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_f16_qdu8_convert; + break; + default: + XNN_UNREACHABLE; + } convert_op->compute[0].range[0] = batch_size; convert_op->compute[1].type = xnn_parallelization_type_1d; @@ -859,17 +901,18 @@ enum xnn_status xnn_reshape_convert_nc_f16_qd8( return xnn_status_success; } -enum xnn_status xnn_reshape_convert_nc_f32_qd8( +enum xnn_status reshape_convert_nc_f32_qx8( xnn_operator_t convert_op, size_t batch_size, size_t channels, size_t input_stride, size_t output_stride, + enum xnn_operator_type expected_type, pthreadpool_t threadpool) { - if (convert_op->type != xnn_operator_type_convert_nc_f32_qd8) { + if (convert_op->type != expected_type) { xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qd8), + xnn_operator_type_to_string(expected_type), xnn_operator_type_to_string(convert_op->type)); return 
xnn_status_invalid_parameter; } @@ -894,7 +937,16 @@ enum xnn_status xnn_reshape_convert_nc_f32_qd8( memcpy(&convert_op->context.f32_qd8_convert.params, &convert_op->params.f32_default, sizeof(convert_op->params.f32_default)); convert_op->compute[0].type = xnn_parallelization_type_1d; - convert_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_f32_qd8_convert; + switch (expected_type) { + case xnn_operator_type_convert_nc_f32_qd8: + convert_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_f32_qd8_convert; + break; + case xnn_operator_type_convert_nc_f32_qdu8: + convert_op->compute[0].task_1d = (pthreadpool_task_1d_t) xnn_compute_f32_qdu8_convert; + break; + default: + XNN_UNREACHABLE; + } convert_op->compute[0].range[0] = batch_size; convert_op->compute[1].type = xnn_parallelization_type_1d; @@ -906,10 +958,55 @@ enum xnn_status xnn_reshape_convert_nc_f32_qd8( return xnn_status_success; } -enum xnn_status xnn_reshape_convert_nc_f32_qp8(xnn_operator_t convert_op, - size_t batch_size, - size_t channels, - size_t input_stride, +enum xnn_status xnn_reshape_convert_nc_f16_qd8( + xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool) +{ + return reshape_convert_nc_f16_qx8(convert_op, batch_size, channels, input_stride, output_stride, xnn_operator_type_convert_nc_f16_qd8, threadpool); +} + +enum xnn_status xnn_reshape_convert_nc_f16_qdu8( + xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool) +{ + return reshape_convert_nc_f16_qx8(convert_op, batch_size, channels, input_stride, output_stride, xnn_operator_type_convert_nc_f16_qdu8, threadpool); +} + +enum xnn_status xnn_reshape_convert_nc_f32_qd8( + xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool) +{ + return reshape_convert_nc_f32_qx8(convert_op, batch_size, channels, input_stride, output_stride, xnn_operator_type_convert_nc_f32_qd8, threadpool); +} + +enum xnn_status xnn_reshape_convert_nc_f32_qdu8( + xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool) +{ + return reshape_convert_nc_f32_qx8(convert_op, batch_size, channels, input_stride, output_stride, xnn_operator_type_convert_nc_f32_qdu8, threadpool); +} + +enum xnn_status xnn_reshape_convert_nc_f32_qp8(xnn_operator_t convert_op, // + size_t num_groups, // + size_t batch_size, // + size_t channels, // + size_t input_stride, // pthreadpool_t threadpool) { if (convert_op->type != xnn_operator_type_convert_nc_f32_qp8) { xnn_log_error( @@ -928,9 +1025,15 @@ enum xnn_status xnn_reshape_convert_nc_f32_qp8(xnn_operator_t convert_op, convert_op->batch_size = batch_size; - const struct xnn_gemm_config* gemm_config = - xnn_init_qp8_f32_qc4w_gemm_config(); - const uint32_t mr_packed = batch_size == 1 ? 1 : gemm_config->mr_packed; + const struct xnn_gemm_config* gemm_config = convert_op->gemm_config; + if (gemm_config == NULL) { + xnn_log_error("failed to setup %s operator: No GEMM config provided.", + xnn_operator_type_to_string(convert_op->type)); + return xnn_status_invalid_parameter; + } + const uint32_t mr_packed = batch_size == 1 ? 1 + : gemm_config->mr_packed ? 
gemm_config->mr_packed + : gemm_config->mr; const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; @@ -941,16 +1044,18 @@ enum xnn_status xnn_reshape_convert_nc_f32_qp8(xnn_operator_t convert_op, .kr = kr, .sr = sr, .lhs_stride = input_stride * sizeof(float), + .group_stride = xnn_x8_packq_f32qp8_packed_size(batch_size, channels, + mr_packed, kr, sr), .packq_ukernel = (xnn_x8_packq_f32qp8_ukernel_fn) convert_op->unary_elementwise_config->ukernel, }; - // TODO(b/340399245) - Ideally, this should parallelize along `batch` in - // groups of `mr`. - convert_op->compute[0].type = xnn_parallelization_type_1d; - convert_op->compute[0].task_1d = - (pthreadpool_task_1d_t)xnn_compute_f32_qp8_convert; - convert_op->compute[0].range[0] = batch_size; + convert_op->compute[0].type = xnn_parallelization_type_2d_tile_1d; + convert_op->compute[0].task_2d_tile_1d = + (pthreadpool_task_2d_tile_1d_t)xnn_compute_f32_qp8_convert; + convert_op->compute[0].range[0] = num_groups; + convert_op->compute[0].range[1] = batch_size; + convert_op->compute[0].tile[0] = mr_packed; convert_op->state = xnn_run_state_needs_setup; @@ -1011,15 +1116,16 @@ enum xnn_status xnn_reshape_copy_nc_x32( threadpool); } -enum xnn_status xnn_setup_convert_nc_f16_qd8( +enum xnn_status setup_convert_nc_f16_qx8( xnn_operator_t convert_op, const void* input, - int8_t* output, + void* output, + enum xnn_operator_type expected_operator_type, struct xnn_quantization_params* quantization_params) { - if (convert_op->type != xnn_operator_type_convert_nc_f16_qd8) { + if (convert_op->type != expected_operator_type) { xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_convert_nc_f16_qd8), + xnn_operator_type_to_string(expected_operator_type), xnn_operator_type_to_string(convert_op->type)); return xnn_status_invalid_parameter; } @@ -1047,15 +1153,16 @@ enum xnn_status xnn_setup_convert_nc_f16_qd8( return xnn_status_success; } -enum xnn_status xnn_setup_convert_nc_f32_qd8( +enum xnn_status setup_convert_nc_f32_qx8( xnn_operator_t convert_op, const float* input, - int8_t* output, + void* output, + enum xnn_operator_type expected_operator_type, struct xnn_quantization_params* quantization_params) { - if (convert_op->type != xnn_operator_type_convert_nc_f32_qd8) { + if (convert_op->type != expected_operator_type) { xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)", - xnn_operator_type_to_string(xnn_operator_type_convert_nc_f32_qd8), + xnn_operator_type_to_string(expected_operator_type), xnn_operator_type_to_string(convert_op->type)); return xnn_status_invalid_parameter; } @@ -1078,11 +1185,48 @@ enum xnn_status xnn_setup_convert_nc_f32_qd8( convert_op->context.f32_qd8_convert.x = input; convert_op->context.f32_qd8_convert.y = output; convert_op->context.f32_qd8_convert.quantization_params = (struct xnn_qd8_quantization_params*) quantization_params; + convert_op->state = xnn_run_state_ready; return xnn_status_success; } +enum xnn_status xnn_setup_convert_nc_f16_qd8( + xnn_operator_t convert_op, + const void* input, + int8_t* output, + struct xnn_quantization_params* quantization_params) +{ + return setup_convert_nc_f16_qx8(convert_op, input, output, xnn_operator_type_convert_nc_f16_qd8, quantization_params); +} + +enum xnn_status xnn_setup_convert_nc_f16_qdu8( + xnn_operator_t convert_op, + const void* input, + uint8_t* output, + struct xnn_quantization_params* 
quantization_params)
+{
+  return setup_convert_nc_f16_qx8(convert_op, input, output, xnn_operator_type_convert_nc_f16_qdu8, quantization_params);
+}
+
+enum xnn_status xnn_setup_convert_nc_f32_qd8(
+    xnn_operator_t convert_op,
+    const float* input,
+    int8_t* output,
+    struct xnn_quantization_params* quantization_params)
+{
+  return setup_convert_nc_f32_qx8(convert_op, input, output, xnn_operator_type_convert_nc_f32_qd8, quantization_params);
+}
+
+enum xnn_status xnn_setup_convert_nc_f32_qdu8(
+    xnn_operator_t convert_op,
+    const float* input,
+    uint8_t* output,
+    struct xnn_quantization_params* quantization_params)
+{
+  return setup_convert_nc_f32_qx8(convert_op, input, output, xnn_operator_type_convert_nc_f32_qdu8, quantization_params);
+}
+
 enum xnn_status xnn_setup_convert_nc_f32_qp8(xnn_operator_t convert_op, const float* input, int8_t* output) {
@@ -1253,4 +1397,4 @@ enum xnn_status xnn_run_convert_nc_f32_f16(
     xnn_datatype_fp16, NULL, NULL, NULL, flags, batch_size, channels, input_stride, output_stride, threadpool, input, output);
-}
\ No newline at end of file
+}
diff --git a/src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c b/src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c
new file mode 100644
index 000000000000..cd8afb04b536
--- /dev/null
+++ b/src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c
@@ -0,0 +1,34 @@
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include
+
+#include "xnnpack/microparams.h"
+
+#if XNN_ENABLE_KLEIDIAI
+  // Keep this line indented to avoid it being pulled out of the #ifdef when the
+  // sources are amalgamated.
+  #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
+#endif // XNN_ENABLE_KLEIDIAI
+
+// Wraps the `kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla`
+// GEMM microkernel with a name that is compatible with our tooling.
+void xnn_pf32_gemm_minmax_ukernel_1x32__neonsme2(
+    size_t m, size_t n, size_t k, const void* lhs_packed, size_t lhs_stride,
+    const void* rhs_packed, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col,
+    union xnn_f32_minmax_params
+        minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) {
+#if XNN_ENABLE_KLEIDIAI
+  kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla(
+      m, n, k / sizeof(float), lhs_packed, lhs_stride, rhs_packed, dst, dst_stride_row, /*dst_stride_col=*/sizeof(float),
+      minmax_params->scalar.min, minmax_params->scalar.max);
+#else
+  assert(
+      "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without "
+      "`XNN_ENABLE_KLEIDIAI`."
&& 0); +#endif // XNN_ENABLE_KLEIDIAI +} + diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c index f4e56c3f4de9..a35209f59ac7 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -152,35 +150,35 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i 
va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i 
va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -239,25 +237,25 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c index a728f742964e..e188665069da 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c @@ -97,18 +97,16 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) 
quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -151,35 +149,35 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -237,25 +235,25 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const 
__m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c index 90a93d0f34c6..db0fc20178de 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); 
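The same mechanical change repeats throughout these generated kernels: the `vsign_mask` XOR on the activations and the matching `+ 128` on the input zero points are both dropped. XORing a signed int8 value with 0x80 is equivalent to adding 128 and reinterpreting it as unsigned, so the two always had to travel together; removing both presumably reflects activations that now arrive already unsigned (the qdu8 paths introduced earlier in this diff), with the offset carried in the data and zero point rather than applied inside the kernel. A minimal check of the underlying identity (illustrative only, not part of the diff):

#include <assert.h>
#include <stdint.h>

static void check_sign_flip_identity(void) {
  for (int v = -128; v <= 127; ++v) {
    const uint8_t flipped = (uint8_t) (((int8_t) v) ^ 0x80);  // what the removed XOR did
    assert(flipped == (uint8_t) (v + 128));                   // same as a +128 bias
  }
}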
+ const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -152,35 +150,35 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -239,25 +237,25 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git 
a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni.c index 05680a983039..39d3e7f4aa36 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnni.c @@ -97,18 +97,16 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -151,35 +149,35 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const 
__m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 
= _mm256_load_si256(w); @@ -237,25 +235,25 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c index 9f4989188976..b3a662a02ffe 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) 
quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -154,35 +152,35 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -239,25 +237,25 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c index 5840e78a6376..465de4c34791 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c @@ -97,18 +97,16 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = 
_mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -153,35 +151,35 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), 
vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -237,25 +235,25 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git 
a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c index c840af54e5a6..5eb1c4b5f800 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -172,41 +170,41 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 
8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -273,29 +271,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const 
__m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c index 616ccfae81cc..fcf2d2a58e71 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c @@ -109,20 +109,18 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 
voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -171,41 +169,41 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -271,29 +269,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c index 3d2a62f28800..fac574704e60 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const 
__m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -172,41 +170,41 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -273,29 +271,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni.c index 8f0df4e6dd3e..6ea847cad5d4 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnni.c @@ -109,20 +109,18 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i 
vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -171,41 +169,41 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); +
const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -271,29 +269,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c index bf226ab32c24..6fd37617a30f 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i 
vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -174,41 +172,41 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF =
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -273,29 +271,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c index c19ee737b211..2acb70d7df8d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c @@ -109,20 +109,18 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = 
_mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -173,41 +171,41 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 =
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), 
vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -271,29 +269,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c index 1dc2d14f1040..51e5ac0d5d0f 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i 
vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -192,47 +190,47 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t)
unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -307,33 +305,33 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c index b518017e3d31..dfc66d1aa586 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c @@ -121,22 +121,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = 
_mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -191,47 +189,47 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t)
unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -305,33 +303,33 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c index 32d6efdef216..4ba3883ae12f 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i 
vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -192,47 +190,47 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -307,33 +305,33 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni.c index e471ecc66cc5..3726fc975043 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnni.c @@ -121,22 +121,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) 
quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -191,47 +189,47 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), 
vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -305,33 +303,33 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), 
vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c index d30a2c8c16ba..31cebe7abc3c 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const 
__m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -194,47 +192,47 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 
8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -307,33 +305,33 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c index c31f5e3797c1..40904c94162b 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c +++ 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c @@ -121,22 +121,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -193,47 +191,47 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), 
vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -305,33 +303,33 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c index ffc9f11701e8..0c2984445b81 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void 
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd.c index c478d7807811..c12ec2f3b6a5 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx2-madd.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c index 7f34e1f615dd..74190a8ce179 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c index a0478be56097..a91df0e3e9f5 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c index 83f98b8f0612..71911e8edf44 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = 
_mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni.c index 614cd4594b05..59a94ee9ff5a 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c index 4c2c5a20b7ef..eca821aebacc 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c +++ 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -66,8 +64,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c index 7ea6a0ba27e6..8b09f9674e63 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -65,8 +63,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c index ee3c1edb5b85..c6d6ba63a33a 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni.c index 8e83d4d0927a..57e44617f15d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avxvnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c index 5720147ca435..200162f28a48 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -76,11 +74,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -107,9 +105,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd.c index 8298b0488443..e09388109f0f 100644 
--- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avx2-madd.c @@ -49,10 +49,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -75,11 +73,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -105,9 +103,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c index e44eb26da516..8648b9e1865e 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = 
_mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -76,11 +74,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -107,9 +105,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni.c index fd68dff21a8c..a528dd5ac45f 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-2x8c8-minmax-avxvnni.c @@ -49,10 +49,8 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -75,11 +73,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -105,9 +103,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c index 3056fe7fcd8d..06d4a38419aa 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -82,14 +80,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -120,11 +118,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd.c index 953a1051ffe0..225a495c6f17 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avx2-madd.c @@ -55,11 +55,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -81,14 +79,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); 
- const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -118,11 +116,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c index ea90a208317b..886d4a77c522 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -82,14 +80,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -120,11 +118,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni.c index 60eced132c03..0971f0d716d6 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-3x8c8-minmax-avxvnni.c @@ -55,11 +55,9 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -81,14 +79,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const 
__m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -118,11 +116,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c index e45800b24e06..0c737d0f6a4d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -92,17 +90,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -137,13 +135,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd.c index a84b416743cb..b8e80471b00d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avx2-madd.c @@ -61,12 +61,10 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const 
uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -91,17 +89,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -135,13 +133,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c index 3611d4e60176..eac6c977fb6e 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - 
XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -92,17 +90,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -137,13 +135,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), 
vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni.c index 94a449d57794..05fe9b5569e5 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-4x8c8-minmax-avxvnni.c @@ -61,12 +61,10 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -91,17 +89,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i 
va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -135,13 +133,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c index 935204f54751..f64d0d00c034 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - 
const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd.c index 246d452f6877..560c4b49c09d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx2-madd.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point 
+ 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c index e551b0833a3a..c644a2a0786a 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - 
const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c index f8811365dcbc..7707df9ee4e3 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) 
quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c index a7804b706b9b..d51cb1eef601 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni.c index 2078e4944377..c19c4332b5ac 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i 
vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c index 67aab3f75faa..da98e4b36d21 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -104,20 +102,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), 
vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c index 83614c77f612..5e069dc7d14f 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) 
quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -103,20 +101,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); 
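/*
 * Illustrative aside (not part of the generated kernels or of this patch): the pattern being
 * removed throughout these qd8-f16-qc4w GEMM kernels is the "sign-flip" trick, where the signed
 * int8 activations are XOR-ed with 0x80 (vsign_mask) to re-bias them into unsigned range and the
 * per-row input zero point is shifted by "+ 128" to compensate. Dropping both the XOR and the
 * "+ 128" together appears to switch these kernels to consuming the activations as signed values
 * directly; the two formulations are arithmetically equivalent, since for any int8 value a the
 * byte (a ^ 0x80), read as uint8, equals a + 128. A minimal self-contained check of that
 * equivalence, assuming nothing about the kernel internals beyond this identity:
 */
#include <assert.h>
#include <stdint.h>

int main(void) {
  for (int a = -128; a <= 127; a++) {
    // Re-biasing by XOR with 0x80 is the same as adding 128 to the signed value.
    const uint8_t flipped = (uint8_t) ((int8_t) a ^ 0x80);
    assert((int) flipped == a + 128);
    // Hence, for any weight w and zero point zp, the zero-point-corrected product is unchanged:
    //   (flipped - (zp + 128)) * w == (a - zp) * w
    const int w = 3;
    const int zp = 5;
    assert(((int) flipped - (zp + 128)) * w == (a - zp) * w);
  }
  return 0;
}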
a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c index 4237189d3c59..b71623093ae6 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const 
__m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni.c index 46aa114ada2c..42cead731c43 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-5x8c8-minmax-avxvnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i 
vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c index e637571acadc..fda203fa7b60 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c @@ -74,14 +74,12 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -112,23 +110,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -171,17 +169,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd.c index 309c6df8f05e..13502c52e6f5 100644 --- 
a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avx2-madd.c @@ -73,14 +73,12 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -111,23 +109,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 
+= 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -169,17 +167,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c index cd0841562578..686fc777fddd 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c @@ -74,14 +74,12 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) 
quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -112,23 +110,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -171,17 +169,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( } if (k != 0) { 
- const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni.c index 86ecad55b519..a6812c81b8a9 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-6x8c8-minmax-avxvnni.c @@ -73,14 +73,12 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -111,23 +109,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -169,17 +167,17 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); 
a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c index 52ce80f792d6..c83f78976b34 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const 
__m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd.c index fc1427b75bc1..0f4d51049308 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx2-madd.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git 
a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c index 4e64355c0836..08d820654aea 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c index b86d10ac868e..0627f62c501e 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c index f0bf7b6e1e98..fa90f9f0a00e 100644 --- 
a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni.c index 44741723c120..2334958fb912 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni.c +++ 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c index 1de2e12c791d..c425b932604d 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c @@ -80,15 +80,13 @@ void 
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -124,26 +122,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c index ce8654739882..82d6654ff7ec 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - 
XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -123,26 +121,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c index 717feedc80a3..7782e133ab54 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = 
_mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i 
va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni.c index fc9ecfa57808..bf56d5add43a 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-7x8c8-minmax-avxvnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = 
_mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c index 1680d1754a9b..d849fa4b9f74 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const 
__m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i 
va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd.c index b678856a4b56..190423f40ec7 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx2-madd.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd( 
c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + 
const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c index e354565fe185..e832712a3903 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - 
const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c index 892899eba182..60a2183ebef9 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF 
= _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c index 56a2da4a1f38..0df6f030bd3c 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = 
_mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), 
vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni.c index 1694ff3d9e6e..f8e9b5e5aba1 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point 
+ 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t)
unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c index 8c0a173ea616..3d2d027817d2 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i 
vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -134,29 +132,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567
= _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git 
a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c index 535a8f918899..fed1619dec72 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,29 +131,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t)
unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c index 31264b392122..c384328e7dc4 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t)
unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const 
__m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni.c index 309356792012..f527fd8ab903 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-8x8c8-minmax-avxvnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max =
_mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t)
unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c index 22ba6dff37f0..c8954510e024 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i 
vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,32 +140,32 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const
__m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -222,23 +220,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c 
b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c index 5a577441340d..0955f7698745 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c @@ -91,17 +91,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,32 +139,32 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t)
unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -220,23 +218,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const 
__m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c index 6853dbde7c27..31e6169fcbdd 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) 
quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,32 +140,32 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567
= _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -222,23 +220,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni.c index c4fbeccba1e1..752b013fe87c 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnni.c @@ -91,17 +91,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const 
__m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,32 +139,32 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const 
__m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -220,23 +218,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c index 0d5712cead48..a844f7d192e8 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -144,32 +142,32 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * 
sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -222,23 +220,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c index 6356a61dbf90..f8a7ac7eda92 100644 --- a/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c @@ -91,17 +91,15 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = 
_mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -143,32 +141,32 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -220,23 +218,23 @@ void xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c index 6084581e6d84..06cd2d70ffdb 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -150,35 +148,35 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + 
const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), 
vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -234,25 +232,25 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni.c index 4274d2d0e7eb..d6d84cdf03da 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-10x8c8-minmax-avx256vnni.c @@ -97,18 +97,16 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) 
quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -149,35 +147,35 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -231,25 +229,25 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i 
va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c index 957239fd54bf..909112cae970 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) 
quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -170,41 +168,41 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -268,29 +266,29 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni.c index 15c6148a937e..eaa353b4585a 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-12x8c8-minmax-avx256vnni.c @@ -109,20 +109,18 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const 
__m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -169,41 +167,41 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i 
va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -265,29 +263,29 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c index d979bd0c9114..6d5413db4556 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) 
quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -190,47 +188,47 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -302,33 +300,33 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni.c index 47956b996c75..0693cbc59f9d 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-14x8c8-minmax-avx256vnni.c @@ -121,22 +121,20 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = 
_mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -189,47 +187,47 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -299,33 +297,33 @@ void 
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c index c8406f4db238..55b0a1ee1b47 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
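[Editor's note on the qd8-f16-qc8w-gemm avx256vnni hunks above] These hunks delete the vsign_mask XOR on the int8 activations together with the "+ 128" added to every per-row zero_point. The two adjustments are a matched pair: XOR-ing an int8 value with 0x80 is the same as adding 128 to it in the unsigned domain (as needed for an unsigned-by-signed VNNI dot product such as vpdpbusd), and biasing the zero point by the same 128 cancels that offset in the compensation term, so removing one without the other would change results. The scalar sketch below is editorial only; it is not part of the diff, the helper names dot_biased/dot_signed are hypothetical, and the real kernels fold the compensation through precomputed weight sums and the vinput_zero_point* vectors. It merely demonstrates the cancellation.

#include <assert.h>
#include <stdint.h>

// Hypothetical scalar stand-ins for the two dot-product flavours.
static int32_t dot_biased(const int8_t a[4], const int8_t w[4], int32_t zero_point) {
  int32_t acc = 0;
  for (int i = 0; i < 4; i++) {
    acc += ((int32_t) a[i] + 128) * (int32_t) w[i];  // like vpdpbusd on (a ^ 0x80)
  }
  const int32_t wsum = (int32_t) w[0] + w[1] + w[2] + w[3];
  return acc - (zero_point + 128) * wsum;            // zero point biased by the same 128
}

static int32_t dot_signed(const int8_t a[4], const int8_t w[4], int32_t zero_point) {
  int32_t acc = 0;
  for (int i = 0; i < 4; i++) {
    acc += (int32_t) a[i] * (int32_t) w[i];          // signed-by-signed accumulation
  }
  const int32_t wsum = (int32_t) w[0] + w[1] + w[2] + w[3];
  return acc - zero_point * wsum;                    // unbiased zero point
}

int main(void) {
  const int8_t a[4] = {-3, 7, -128, 127};
  const int8_t w[4] = {5, -9, 2, -1};
  // The +128 input bias and the +128 zero-point bias cancel exactly.
  assert(dot_biased(a, w, 4) == dot_signed(a, w, 4));
  return 0;
}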
 #include <assert.h>
+#if defined(__has_feature)
+  #if __has_feature(memory_sanitizer)
+    #include <sanitizer/msan_interface.h>
+  #endif
+#endif
 #include <immintrin.h>
@@ -44,10 +49,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm(
   // TODO: amxintrin.h only provide intrinsics for __x86_64__
   // Update if amxintrin changes
   #if defined(__x86_64__)
-  __attribute__((aligned(64))) int32_t res0[16 * 16];
-  __attribute__((aligned(64))) int32_t res1[16 * 16];
-  __attribute__((aligned(64))) int32_t res2[16 * 16];
-  __attribute__((aligned(64))) int32_t res3[16 * 16];
+  __attribute__((aligned(64))) int32_t res[4][16 * 16];
   kc = round_up_po2(kc, 4 * sizeof(int8_t));
   const size_t kremainder = (kc & 63) ? (kc & 63) : 64;
@@ -66,19 +68,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm(
   // Load tile configuration
   __attribute__((aligned(64))) struct __tile_config tile_data = {0};
   tile_data.palette_id = 1;
-  tile_data.rows[0] = mr;               // tmm0 = res 0
-  tile_data.rows[1] = mr;               // tmm1 = res 1
-  tile_data.rows[2] = mr;               // tmm2 = res 2
-  tile_data.rows[3] = mr;               // tmm3 = res 3
+  tile_data.rows[0] = mr;               // tmm0 = res[0]
+  tile_data.rows[1] = mr;               // tmm1 = res[1]
+  tile_data.rows[2] = mr;               // tmm2 = res[2]
+  tile_data.rows[3] = mr;               // tmm3 = res[3]
   tile_data.rows[4] = mr;               // tmm4 = input
   tile_data.rows[5] = 16;               // tmm5 = weights
   tile_data.rows[6] = mr;               // tmm6 = input remainder
   tile_data.rows[7] = kremainder >> 2;  // tmm7 = weights remainder
-  tile_data.colsb[0] = 64;          // tmm0 = res 0
-  tile_data.colsb[1] = 64;          // tmm1 = res 1
-  tile_data.colsb[2] = 64;          // tmm2 = res 1
-  tile_data.colsb[3] = 64;          // tmm3 = res 1
+  tile_data.colsb[0] = 64;          // tmm0 = res[0]
+  tile_data.colsb[1] = 64;          // tmm1 = res[1]
+  tile_data.colsb[2] = 64;          // tmm2 = res[2]
+  tile_data.colsb[3] = 64;          // tmm3 = res[3]
   tile_data.colsb[4] = 64;          // tmm4 = input
   tile_data.colsb[5] = 64;          // tmm5 = weights
   tile_data.colsb[6] = kremainder;  // tmm6 = input remainder
@@ -282,12 +284,22 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm(
       k -= kremainder * sizeof(int8_t);
     }
-    // Add tile to bias
-    _tile_stored(0, res0, 64);
-    _tile_stored(1, res1, 64);
-    _tile_stored(2, res2, 64);
-    _tile_stored(3, res3, 64);
+    // TODO: Instead of processing up to 4 tiles (16x64) consider
+    // quantizing 1 tile at a time (16 registers)
+    _tile_stored(0, &res[0][0], 64);
+    _tile_stored(1, &res[1][0], 64);
+    _tile_stored(2, &res[2][0], 64);
+    _tile_stored(3, &res[3][0], 64);
+
+    // TODO: Fix msan for AMX
+    #if defined(__has_feature)
+      #if __has_feature(memory_sanitizer)
+        __msan_unpoison(res, sizeof(res));
+      #endif
+    #endif
+    // TODO: Instead of processing up to 4 tiles (16x64) consider
+    // quantizing 1 row at a time.
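[Editor's note on the __msan_unpoison lines added above] _tile_stored writes the AMX tile registers to res through an instruction MemorySanitizer does not instrument, so MSan would otherwise flag the later _mm512_load_epi32 reads of res as uses of uninitialized memory; the added block marks the whole buffer as initialized once the tiles are stored (hence "TODO: Fix msan for AMX"). Below is a minimal, self-contained sketch of the same pattern; the function and an ordinary memset standing in for the AMX tile store are illustrative only.

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    #include <sanitizer/msan_interface.h>
  #endif
#endif

// Stand-in for a store MSan cannot see (in the kernel this is _tile_stored).
static void opaque_store(int32_t* dst, size_t n) {
  memset(dst, 0, n * sizeof(int32_t));
}

int32_t sum_after_opaque_store(void) {
  __attribute__((aligned(64))) int32_t res[4][16 * 16];
  opaque_store(&res[0][0], 4 * 16 * 16);
  #if defined(__has_feature)
    #if __has_feature(memory_sanitizer)
      // Mirror the kernel: mark the buffer filled by the opaque store as initialized.
      __msan_unpoison(res, sizeof(res));
    #endif
  #endif
  int32_t acc = 0;
  for (size_t i = 0; i < 4 * 16 * 16; i++) {
    acc += (&res[0][0])[i];
  }
  return acc;
}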
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -352,70 +364,71 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = 
_mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
   __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV);
diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx.c
index 691be59a91f6..c6fb6a8c4c64 100644
--- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx.c
+++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-16x64c4-minmax-avx512amx.c
@@ -8,6 +8,11 @@
 // LICENSE file in the root directory of this source tree.
 #include <assert.h>
+#if defined(__has_feature)
+  #if __has_feature(memory_sanitizer)
+    #include <sanitizer/msan_interface.h>
+  #endif
+#endif
 #include <immintrin.h>
@@ -43,10 +48,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx(
   // TODO: amxintrin.h only provide intrinsics for __x86_64__
   // Update if amxintrin changes
   #if defined(__x86_64__)
-  __attribute__((aligned(64))) int32_t res0[16 * 16];
-  __attribute__((aligned(64))) int32_t res1[16 * 16];
-  __attribute__((aligned(64))) int32_t res2[16 * 16];
-  __attribute__((aligned(64))) int32_t res3[16 * 16];
+  __attribute__((aligned(64))) int32_t res[4][16 * 16];
   kc = round_up_po2(kc, 4 * sizeof(int8_t));
   const size_t kremainder = (kc & 63) ? (kc & 63) : 64;
@@ -65,19 +67,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx(
   // Load tile configuration
   __attribute__((aligned(64))) struct __tile_config tile_data = {0};
   tile_data.palette_id = 1;
-  tile_data.rows[0] = mr;               // tmm0 = res 0
-  tile_data.rows[1] = mr;               // tmm1 = res 1
-  tile_data.rows[2] = mr;               // tmm2 = res 2
-  tile_data.rows[3] = mr;               // tmm3 = res 3
+  tile_data.rows[0] = mr;               // tmm0 = res[0]
+  tile_data.rows[1] = mr;               // tmm1 = res[1]
+  tile_data.rows[2] = mr;               // tmm2 = res[2]
+  tile_data.rows[3] = mr;               // tmm3 = res[3]
   tile_data.rows[4] = mr;               // tmm4 = input
   tile_data.rows[5] = 16;               // tmm5 = weights
   tile_data.rows[6] = mr;               // tmm6 = input remainder
   tile_data.rows[7] = kremainder >> 2;  // tmm7 = weights remainder
-  tile_data.colsb[0] = 64;          // tmm0 = res 0
-  tile_data.colsb[1] = 64;          // tmm1 = res 1
-  tile_data.colsb[2] = 64;          // tmm2 = res 1
-  tile_data.colsb[3] = 64;          // tmm3 = res 1
+  tile_data.colsb[0] = 64;          // tmm0 = res[0]
+  tile_data.colsb[1] = 64;          // tmm1 = res[1]
+  tile_data.colsb[2] = 64;          // tmm2 = res[2]
+  tile_data.colsb[3] = 64;          // tmm3 = res[3]
   tile_data.colsb[4] = 64;          // tmm4 = input
   tile_data.colsb[5] = 64;          // tmm5 = weights
   tile_data.colsb[6] = kremainder;  // tmm6 = input remainder
@@ -201,12 +203,22 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx(
       k -= kremainder * sizeof(int8_t);
     }
-    // Add tile to bias
-    _tile_stored(0, res0, 64);
-    _tile_stored(1, res1, 64);
-    _tile_stored(2, res2, 64);
-    _tile_stored(3, res3, 64);
+    // TODO: Instead of processing up to 4 tiles (16x64) consider
+    // quantizing 1 tile at a time (16 registers)
+    _tile_stored(0, &res[0][0], 64);
+    _tile_stored(1, &res[1][0], 64);
+    _tile_stored(2, &res[2][0], 64);
+    _tile_stored(3, &res[3][0], 64);
+
+    // TODO: Fix msan for AMX
+    #if defined(__has_feature)
+      #if __has_feature(memory_sanitizer)
+        __msan_unpoison(res, sizeof(res));
+      #endif
+    #endif
+    // TODO: Instead of processing up to 4 tiles (16x64) consider
+    // quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -271,70 +283,71 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = 
_mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x64c4-minmax-avx512amx.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x64c4-minmax-avx512amx.c index 9f28afbaa5c5..9ef21746cb14 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x64c4-minmax-avx512amx.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,10 +48,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +67,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -141,20 +143,31 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
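The remaining hunks switch the AVX-VNNI and AVX256-VNNI `qd8-f16-qc8w` GEMM kernels from sign-flipped to plainly signed activations: the `vsign_mask` constant and the `_mm256_xor_si256(..., vsign_mask)` on every activation load disappear, and the matching `+ 128` is dropped from each `vinput_zero_point`. The old form flipped the sign bit of every int8 activation byte (equivalent to adding 128, turning it into an unsigned byte suitable for an unsigned-by-signed dot product) and offset the per-row zero point by 128 so the zero-point correction cancelled the extra term; the new form presumably feeds the signed bytes to a dot-product path that accepts them directly, which this excerpt does not show. A scalar sketch of the identity behind the removed offset, with hypothetical values and not the kernel's exact accumulation order:

```c
#include <assert.h>
#include <stdint.h>

// Why XOR-ing int8 activations with 0x80 plus a zero point offset of +128
// matches using the signed values directly: (uint8_t)(a ^ 0x80) == a + 128,
// and the extra 128 * sum(w) term is cancelled by the +128 on the zero point.
int main(void) {
  const int8_t a[4] = {-7, 100, -128, 3};  // activations (hypothetical values)
  const int8_t w[4] = {5, -2, 9, -1};      // weights (hypothetical values)
  const int32_t zero_point = 11;

  int32_t acc_signed = 0, acc_biased = 0, wsum = 0;
  for (int i = 0; i < 4; i++) {
    acc_signed += (int32_t) a[i] * w[i];
    acc_biased += (int32_t) (uint8_t) (a[i] ^ 0x80) * w[i];  // old: sign bit flipped
    wsum += w[i];
  }
  // Old formulation: biased accumulator, zero point offset by 128.
  const int32_t old_result = acc_biased - (zero_point + 128) * wsum;
  // New formulation: signed accumulator, original zero point.
  const int32_t new_result = acc_signed - zero_point * wsum;
  assert(old_result == new_result);
  return 0;
}
```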
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c index e9599fb475b3..ca080afd4358 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -83,7 +81,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git 
a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni.c index b318b05d9101..3ea630a7c8b6 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -61,8 +59,8 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -80,7 +78,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c index 7fa678cc6a48..19538be00d28 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = 
_mm256_load_si256(w); @@ -83,7 +81,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni.c index 34f293a43b54..4458a2b37b98 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avxvnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni( const int8_t* a0 = a; uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -61,8 +59,8 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -80,7 +78,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c index 265ac6acdfdb..7b7e9103fbc7 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -74,11 +72,11 @@ void 
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -102,9 +100,9 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni.c index fa6378a37572..c177d81cbb5b 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-2x8c8-minmax-avxvnni.c @@ -49,10 +49,8 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -73,11 +71,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -99,9 +97,9 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c index 3b88a6a6ba79..e6067caee892 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -80,14 +78,14 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ 
-115,11 +113,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni.c index ee9fca470f72..8e64a5d8d672 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-3x8c8-minmax-avxvnni.c @@ -55,11 +55,9 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -79,14 +77,14 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = 
_mm256_load_si256(w); @@ -112,11 +110,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c index 5c6427cb425f..4290781ca32b 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -90,17 +88,17 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 
8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -132,13 +130,13 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni.c index 3ace0e24f6da..a13044743c39 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c8-minmax-avxvnni.c @@ -61,12 +61,10 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -89,17 +87,17 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 
8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -129,13 +127,13 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c index 89072ac9828d..b5b4c351472c 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) 
quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -100,20 +98,20 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -149,15 +147,15 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 
+= 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni.c index a9bc59793c8e..6a6c474dbafb 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avx256vnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -99,20 +97,20 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i 
va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -146,15 +144,15 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c index d7dce3c91393..b840a4396a4e 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i 
vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -100,20 +98,20 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -149,15 +147,15 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni.c index bfe58840f084..b38161c4fcbc 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-5x8c8-minmax-avxvnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -99,20 +97,20 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t)
unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -146,15 +144,15 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c index 9a8b858da432..684784539e9a 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c @@ -74,14 +74,12 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) 
quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -110,23 +108,23 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -166,17 +164,17 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)),
vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni.c index b67e709fee4b..615f62fc003d 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-6x8c8-minmax-avxvnni.c @@ -73,14 +73,12 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -109,23 +107,23 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 =
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -163,17 +161,17 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + 
const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x64c4-minmax-avx512amx.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x64c4-minmax-avx512amx.c index aef907ff7a9f..41cfff795afc 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x64c4-minmax-avx512amx.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,10 +48,7 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +67,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -165,12 +167,22 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -199,34 +211,35 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c index 4a441d232771..df687512e90a 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c +++ 
b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -120,26 +118,26 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 =
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -183,19 +181,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni.c index 3f30b4394e8e..5ad717cff3a0 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avx256vnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni( c6 = c5; } - const __m256i vsign_mask = 
_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -119,26 +117,26 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 =
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -180,19 +178,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c index 74251e8decd1..a3a80f5f7f12 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i 
vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -120,26 +118,26 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); +
const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -183,19 +181,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni.c index b8930384e209..25bb11c23c60 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-7x8c8-minmax-avxvnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = 
_mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -119,26 +117,26 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 =
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -180,19 +178,19 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c index 34a93f070331..6b2fbf87460e 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const 
__m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -130,29 +128,29 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i
va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -200,21 +198,21 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni.c index 5a0e3bf51b16..57c7c40c7c61 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni( c7 = 
c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -129,29 +127,29 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); +
const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -197,21 +195,21 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c index 3a48fdb22833..72592c437874 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -130,29 +128,29 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i
va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -200,21 +198,21 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni.c index d7a39cdd7e4e..df6e675cd36d 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avxvnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -129,29 +127,29 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)),
vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -197,21 +195,21 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c index 9041fc4304bd..5cb280716a9c 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + 
const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -140,32 +138,32 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t)
unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -217,23 +215,23 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni.c index e693f2a59c1c..154df8421738 100644 --- a/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-9x8c8-minmax-avx256vnni.c @@ -91,17 +91,15 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = 
_mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); // XNN_FORCE_REALIZATION(voutput_min); @@ -139,32 +137,32 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF =
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -214,23 +212,23 @@ void xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 
+= 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c index a8207569898d..a200dfae4c3a 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c @@ -83,9 +83,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -184,35 +182,35 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)),
vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -268,25 +266,25 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni.c index 8b55f1148468..90a13af3e134 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-10x8c8-minmax-avx256vnni.c @@ -82,9 +82,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -183,35 +181,35 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t)
unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -265,25 +263,25 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i 
va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c index 249d35fd8d21..4d97009aca23 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c @@ -91,9 +91,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -208,41 +206,41 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF =
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i 
va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -306,29 +304,29 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni.c index 4992b1ea97b3..aa249e287bf5 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-12x8c8-minmax-avx256vnni.c @@ -90,9 +90,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -207,41 +205,41 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 
8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -303,29 +301,29 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i 
va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c index aad279e726ce..c5a3dc8dd227 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c @@ -99,9 +99,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -232,47 +230,47 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const 
__m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -344,33 +342,33 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni.c index 5ddd1748b896..4aeb391884ad 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-14x8c8-minmax-avx256vnni.c @@ -98,9 +98,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -231,47 +229,47 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -341,33 +339,33 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c index ac149df2754c..371e5fdb9736 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -47,10 +52,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -70,19 +72,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -583,12 +585,23 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -653,70 +666,70 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = 
_mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = 
_mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + 
vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 
vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx.c index be91c8cd64a9..92b13f83aff0 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-16x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,10 +51,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -69,19 +71,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -422,12 +424,23 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -492,70 +505,70 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = 
_mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x64c4-minmax-avx512amx.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x64c4-minmax-avx512amx.c index 8a3cd9fe639d..72cc099f8ec8 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x64c4-minmax-avx512amx.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,10 +51,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -69,19 +71,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -162,20 +164,31 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
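Note on the __msan_unpoison() calls added in these AMX kernels: MemorySanitizer does not instrument the _tile_stored() writes, so the res scratch array would still be treated as uninitialized when the vector loads read it back, producing false positives; unpoisoning the array right after the stores silences them, and the TODO marks this as a workaround rather than a proper fix. A standalone sketch of the pattern (the helper name is illustrative only; __msan_unpoison is declared in <sanitizer/msan_interface.h>):

  #include <stdint.h>
  #if defined(__has_feature)
    #if __has_feature(memory_sanitizer)
      #include <sanitizer/msan_interface.h>
    #endif
  #endif

  // Tell MSan that the 4 x 16 x 16 int32 scratch tiles were written by _tile_stored().
  static inline void unpoison_amx_scratch(int32_t (*res)[16 * 16]) {
    #if defined(__has_feature)
      #if __has_feature(memory_sanitizer)
        __msan_unpoison(res, 4 * 16 * 16 * sizeof(int32_t));
      #endif
    #endif
    (void) res;  // no-op when MSan is not enabled
  }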
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c index f50d45f303c1..8034b4b7cbb5 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c @@ -47,9 +47,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm( kc = round_up_po2(kc, 8 * sizeof(int8_t)); uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -78,8 +76,8 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -99,7 +97,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + 
const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni.c index 5817541fd158..91eec4369978 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni.c @@ -46,9 +46,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni( kc = round_up_po2(kc, 8 * sizeof(int8_t)); uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -77,8 +75,8 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -96,7 +94,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c index 3867b7ae9561..c05d8d2006d5 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c @@ -47,9 +47,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm( kc = round_up_po2(kc, 8 * sizeof(int8_t)); uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -78,8 +76,8 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -99,7 +97,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni.c index b55db8a80efb..35f0bb51131f 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avxvnni.c @@ -46,9 +46,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni( kc = round_up_po2(kc, 8 * sizeof(int8_t)); uint16_t* c0 = (uint16_t*) c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -77,8 +75,8 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -96,7 +94,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c index 8a5abdbf3d66..a3288a490251 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c @@ -51,9 +51,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = 
_mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -92,11 +90,11 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -120,9 +118,9 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni.c index a55d1ec8e86d..9a958aa22382 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-2x8c8-minmax-avxvnni.c @@ -50,9 +50,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -91,11 +89,11 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i 
va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -117,9 +115,9 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c index 1de92e56ba00..04a59d8dae80 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c @@ -55,9 +55,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -100,14 +98,14 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -135,11 +133,11 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; 
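Note on the recurring change in these AVX-VNNI / AVX256-VNNI hunks: the vsign_mask XOR of the activations and the "+ 128" folded into vinput_zero_point are removed together because they were two halves of one trick. XORing a signed int8 a with 0x80 re-expresses it as the unsigned value a + 128 (as an unsigned-by-signed dot product requires), which injects an extra 128 * sum(w) into each accumulator; using zero_point + 128 in the weight-sum compensation term cancelled that bias along with the real zero point. Dropping both leaves the arithmetic unchanged, presumably because the accumulate path now consumes signed activations directly (those hunks fall outside this excerpt). A scalar sketch of the algebra, not the kernel itself:

  #include <stdint.h>
  #include <stddef.h>

  // Computes sum_k (a[k] - zp) * w[k] for one output channel.
  int32_t dot_minus_zero_point(const int8_t* a, const int8_t* w, size_t kc, int32_t zp) {
    int32_t acc = 0;   // raw dot product
    int32_t wsum = 0;  // per-channel weight sum ("ksum" in the packed weights)
    for (size_t k = 0; k < kc; k++) {
      // Old scheme: the instruction saw (a[k] ^ 0x80) == a[k] + 128 as uint8, so the raw
      // accumulator was acc + 128 * wsum, and the compensation used (zp + 128) * wsum.
      // New scheme: accumulate signed a[k] directly and compensate with zp * wsum.
      acc += (int32_t) a[k] * (int32_t) w[k];
      wsum += (int32_t) w[k];
    }
    return acc - zp * wsum;  // both schemes reduce to this value
  }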
- const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni.c index 0f6fca318616..dd70e1969969 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-3x8c8-minmax-avxvnni.c @@ -54,9 +54,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -99,14 +97,14 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -132,11 +130,11 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 
= _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c index 02738b78b5d1..ed737ad590f5 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c @@ -59,9 +59,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -112,17 +110,17 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -154,13 +152,13 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni.c index 8a568270b8bb..a84c636434db 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c8-minmax-avxvnni.c @@ -58,9 +58,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -111,17 +109,17 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -151,13 +149,13 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c index 05d59d382ad8..959ccd9ad7bb 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c @@ -63,9 +63,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -124,20 +122,20 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -173,15 +171,15 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni.c index e3e28ce7afc6..04634b34d73b 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avx256vnni.c @@ -62,9 +62,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -123,20 +121,20 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -170,15 +168,15 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c index 8bd640f7f549..9426e8fdc5c7 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c @@ -63,9 +63,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const 
uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -124,20 +122,20 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -173,15 +171,15 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; 
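Note on the activation loads in these Nx8c8 kernels: kc is rounded up to 8 bytes, the main loop consumes two 8-byte groups (16 bytes) per row per iteration, and the `if (k != 0)` tail handles the one remaining 8-byte group. Each unaligned_load_u64() reads eight consecutive quantized activations for one row, and _mm256_set1_epi64x() broadcasts them across the register so the packed weight blocks for all eight output channels can be multiplied against the same activations. A small sketch of that broadcast (unaligned_load_u64 in XNNPACK is essentially the memcpy shown here):

  #include <immintrin.h>
  #include <stdint.h>
  #include <string.h>

  // Broadcast eight consecutive int8 activations (one c8 group) to every 64-bit
  // lane of a ymm register, mirroring unaligned_load_u64 + _mm256_set1_epi64x above.
  static inline __m256i broadcast_c8_group(const int8_t* a) {
    uint64_t bits;
    memcpy(&bits, a, sizeof(bits));             // safe unaligned 8-byte load
    return _mm256_set1_epi64x((int64_t) bits);  // same 8 activations in all 4 lanes
  }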
const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni.c index 9cfcb9df14ac..a647958c1f33 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-5x8c8-minmax-avxvnni.c @@ -62,9 +62,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -123,20 +121,20 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -170,15 +168,15 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i 
va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c index af3d2eda052d..7f200fff8bd9 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c @@ -67,9 +67,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -136,23 +134,23 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -192,17 +190,17 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni.c index 83326c449c5c..78fdec95a31a 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-6x8c8-minmax-avxvnni.c @@ -66,9 +66,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) 
¶ms->scalar.max)); @@ -135,23 +133,23 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -189,17 +187,17 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + 
const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x64c4-minmax-avx512amx.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x64c4-minmax-avx512amx.c index 6d7808a7f7af..4175981a69b6 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x64c4-minmax-avx512amx.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,10 +51,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -69,19 +71,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -278,12 +280,23 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( p -= 7 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -312,34 +325,34 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = 
_mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c index 
b7944d354e24..57b4244f5241 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c @@ -71,9 +71,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -148,26 +146,26 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -211,19 +209,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni.c index 4d2e9bb1226e..7dec530edf0f 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avx256vnni.c @@ -70,9 +70,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -147,26 +145,26 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -208,19 +206,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c index 37408fe3ba2d..3d9988bb08ec 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c @@ -71,9 +71,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -148,26 +146,26 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i 
va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -211,19 +209,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni.c index 1af16be34a51..81bbfc61a0c7 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-7x8c8-minmax-avxvnni.c @@ -70,9 +70,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i 
vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -147,26 +145,26 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -208,19 +206,19 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni( } if (k != 
0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c index 8b244632de7f..f35b9f0a7b0b 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c @@ -75,9 +75,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -160,29 +158,29 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -230,21 +228,21 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c index 02da8f8df3ed..5758a99dea3b 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c @@ -74,9 +74,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) ¶ms->scalar.max)); @@ -159,29 +157,29 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -227,21 +225,21 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c index 6b3955ba189c..4e14458edcb3 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c @@ -75,9 +75,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -160,29 +158,29 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)),
vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -230,21 +228,21 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni.c index 67fdbe65f544..22a4fd84346c 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avxvnni.c @@ -74,9 +74,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i 
vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -159,29 +157,29 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF =
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -227,21 +225,21 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c index f1e67eba4269..f9cd708b8582 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c @@ -79,9 +79,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -172,32 +170,32 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); +
const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i 
vb01234567x0123 = _mm256_load_si256(w); @@ -249,23 +247,23 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni.c b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni.c index e2705f243050..b3799a20b871 100644 --- a/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni.c +++ b/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-9x8c8-minmax-avx256vnni.c @@ -78,9 +78,7 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -171,32 +169,32 @@ void xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i
va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -246,23 +244,23 @@ void 
xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni-prfm.c index 854b9ebc3b80..a8608a53af2a 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni-prfm.c @@ -104,18 +104,16 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = 
_mm512_set1_ps((float) quantization_params[9].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -159,35 +157,35 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( __m512i vacc9x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - 
const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -247,25 +245,25 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni.c index ca4d31a9f1eb..a8ac3e8c86b1 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnni.c @@ -103,18 +103,16 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -158,35 +156,35 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni( __m512i vacc9x89ABCDEF = _mm512_setzero_epi32(); size_t 
k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -244,25 +242,25 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c index 97e3b7d35278..d4e83d8911ee 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c @@ -104,18 +104,16 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const 
__m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -161,35 +159,35 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm( __m512i vacc9x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - 
const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -247,25 +245,25 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni.c index f23d806484cc..4166ea810d06 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-10x16c8-minmax-avx512vnnigfni.c @@ -103,18 +103,16 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) 
quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -160,35 +158,35 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni( __m512i vacc9x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i 
va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -244,25 +242,25 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni-prfm.c index 1870c0e3b04d..c96e979790d2 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni-prfm.c @@ -116,20 +116,18 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -179,41 +177,41 @@ void 
xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( __m512i vacc11x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -281,29 +279,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i 
va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni.c index d1ac24ee6042..0a68a009ee87 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnni.c @@ -115,20 +115,18 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = 
_mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -178,41 +176,41 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni( __m512i vacc11x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -278,29 +276,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - 
const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c index 5a2e45faa778..0378742420b1 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c @@ -116,20 +116,18 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = 
_mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -181,41 +179,41 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm( __m512i vacc11x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -281,29 +279,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); 
a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni.c index bfd23dc7179f..e6faa434d543 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-12x16c8-minmax-avx512vnnigfni.c @@ -115,20 +115,18 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 
vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -180,41 +178,41 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni( __m512i vacc11x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -278,29 +276,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni-prfm.c index d5e3d447c9df..ac09a18af160 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni-prfm.c @@ -128,22 +128,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = 
_mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); - const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point + 128); - const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); + const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point); + const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -199,47 +197,47 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( __m512i vacc13x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const 
__m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -315,33 +313,33 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni.c index b58b4f6f88be..bcd56c5415c1 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnni.c @@ -127,22 +127,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); - const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point + 128); - const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); + const 
__m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point); + const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -198,47 +196,47 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni( __m512i vacc13x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - 
const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -312,33 +310,33 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c index b2e9a30b539d..25eda1ae3dc1 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c @@ -128,22 +128,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const 
__m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); - const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point + 128); - const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); + const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point); + const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -201,47 +199,47 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm( __m512i vacc13x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - 
const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -315,33 +313,33 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni.c index c095ea3de7d9..ff78e319ea50 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-14x16c8-minmax-avx512vnnigfni.c @@ -127,22 +127,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); - const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point + 128); - const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point + 128); - const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point + 128); - const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point + 128); - const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 
vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); + const __m512 vinput_zero_point9 = _mm512_set1_ps((float) quantization_params[9].zero_point); + const __m512 vinput_zero_point10 = _mm512_set1_ps((float) quantization_params[10].zero_point); + const __m512 vinput_zero_point11 = _mm512_set1_ps((float) quantization_params[11].zero_point); + const __m512 vinput_zero_point12 = _mm512_set1_ps((float) quantization_params[12].zero_point); + const __m512 vinput_zero_point13 = _mm512_set1_ps((float) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -200,47 +198,47 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni( __m512i vacc13x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -312,33 +310,33 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni( } if (k != 0) { - 
const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni-prfm.c index f6b5dcaf1d7b..50a3d098021f 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni-prfm.c @@ -50,9 +50,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - 
XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -71,8 +69,8 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( __m512i vacc1x0x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -96,7 +94,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni.c index 15e6e8eedd4c..d6380a56d31b 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnni.c @@ -49,9 +49,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -70,8 +68,8 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni( __m512i vacc1x0x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -93,7 +91,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c index 
0bb3142f0a9d..41acc70b196e 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c @@ -50,9 +50,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -73,8 +71,8 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm( __m512i vacc1x0x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -96,7 +94,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni.c index ca985d53515e..ace946fc90e3 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-1x16c8-minmax-avx512vnnigfni.c @@ -49,9 +49,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -72,8 +70,8 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni( __m512i vacc1x0x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -93,7 +91,7 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i 
va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni-prfm.c index e56094bdbf2b..301cc0253321 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni-prfm.c @@ -74,13 +74,11 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -109,20 +107,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( __m512i vacc4x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i 
va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -162,15 +160,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni.c index 181b677f3c87..c8634aeb4a4b 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnni.c @@ -73,13 +73,11 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); const __m512 voutput_min = 
_mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -108,20 +106,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni( __m512i vacc4x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -159,15 +157,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c index 08a0d21d8891..7373421ed78f 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c @@ -74,13 +74,11 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -111,20 +109,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm( __m512i vacc4x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -162,15 +160,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni.c index 303a049ef708..2786103c7e3a 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-5x16c8-minmax-avx512vnnigfni.c @@ -73,13 +73,11 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // 
XNN_FORCE_REALIZATION(voutput_min); @@ -110,20 +108,20 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni( __m512i vacc4x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -159,15 +157,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); 
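
Note (editorial sketch, not part of the patch): every hunk in these generated qd8-f32-qb4w AVX512VNNI kernels makes the same paired change — the vsign_mask XOR on the int8 activation bytes is dropped together with the matching "+ 128" on the per-row input zero point. The two are tied by a small identity: XOR-ing a signed int8 value with 0x80 gives the same number, reinterpreted as unsigned, as adding 128, so a kernel that flips the activations must also shift the zero point, and removing one without the other would change the dequantized result. The standalone C snippet below is a hypothetical check of that identity over the full int8 range; the harness and variable names are illustrative and assume nothing about the kernels beyond the arithmetic visible in the hunks.

    /* check_sign_flip_identity.c -- editorial sketch, not part of XNNPACK */
    #include <assert.h>
    #include <stdint.h>

    int main(void) {
      for (int v = -128; v <= 127; v++) {
        const int8_t a = (int8_t) v;
        /* old kernels: activation byte XOR-ed with the 0x80 vsign_mask */
        const uint8_t flipped = (uint8_t) (a ^ 0x80);
        /* old kernels: zero point compensated with "+ 128" */
        assert((int) flipped == (int) a + 128);
      }
      return 0;
    }
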
diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni-prfm.c index ca6cbba343df..81a37a179f7a 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni-prfm.c @@ -86,15 +86,13 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -129,26 +127,26 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( __m512i vacc6x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 
16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -196,19 +194,19 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni.c index c3d25a98541f..46d9763d8c5d 100644 --- 
a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnni.c @@ -85,15 +85,13 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -128,26 +126,26 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni( __m512i vacc6x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), 
vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -193,19 +191,19 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c index 2632d69e01e8..be1401479799 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c @@ -86,15 +86,13 @@ void 
xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,26 +129,26 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm( __m512i vacc6x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -196,19 +194,19 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni.c index 5ce18737b282..d8c64620ce02 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-7x16c8-minmax-avx512vnnigfni.c @@ -85,15 +85,13 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const 
__m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -130,26 +128,26 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni( __m512i vacc6x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -193,19 +191,19 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni-prfm.c index 78d98538653f..ad79a6627b83 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni-prfm.c @@ -92,16 +92,14 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point 
+ 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -139,29 +137,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( __m512i vacc7x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -213,21 +211,21 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni.c index 
fa80e70222e7..2a1dcc545c44 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni.c @@ -91,16 +91,14 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -138,29 +136,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni( __m512i vacc7x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 
+= 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -210,21 +208,21 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c index 0d873cfed8a3..e5edaa49a652 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c @@ -92,16 +92,14 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,29 +139,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm( __m512i vacc7x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -213,21 +211,21 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni.c index be68aca4e0c1..99df746e2764 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnnigfni.c @@ -91,16 +91,14 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -140,29 +138,29 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni( __m512i vacc7x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * 
sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -210,21 +208,21 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni-prfm.c index f6270d906e77..1f6a726fd2c5 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni-prfm.c @@ -98,17 +98,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) 
quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -149,32 +147,32 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( __m512i vacc8x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -230,23 +228,23 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni.c index d8af343bb0e1..757c3dc3f541 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnni.c @@ -97,17 +97,15 @@ void 
xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -148,32 +146,32 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni( __m512i vacc8x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const 
__m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -227,23 +225,23 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c index c881f0fa213f..0fd9dd3af8b0 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c @@ -98,17 +98,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // 
XNN_FORCE_REALIZATION(voutput_min); @@ -151,32 +149,32 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm( __m512i vacc8x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const 
__m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -230,23 +228,23 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni.c index c31c81e7a27d..b7ef5f35e291 100644 --- a/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-9x16c8-minmax-avx512vnnigfni.c @@ -97,17 +97,15 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni( assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point + 128); - const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point + 128); - const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point + 128); - const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point + 128); - const __m512 vinput_zero_point4 = _mm512_set1_ps((float) 
quantization_params[4].zero_point + 128); - const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point + 128); - const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point + 128); - const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point + 128); - const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point + 128); + const __m512 vinput_zero_point0 = _mm512_set1_ps((float) quantization_params[0].zero_point); + const __m512 vinput_zero_point1 = _mm512_set1_ps((float) quantization_params[1].zero_point); + const __m512 vinput_zero_point2 = _mm512_set1_ps((float) quantization_params[2].zero_point); + const __m512 vinput_zero_point3 = _mm512_set1_ps((float) quantization_params[3].zero_point); + const __m512 vinput_zero_point4 = _mm512_set1_ps((float) quantization_params[4].zero_point); + const __m512 vinput_zero_point5 = _mm512_set1_ps((float) quantization_params[5].zero_point); + const __m512 vinput_zero_point6 = _mm512_set1_ps((float) quantization_params[6].zero_point); + const __m512 vinput_zero_point7 = _mm512_set1_ps((float) quantization_params[7].zero_point); + const __m512 vinput_zero_point8 = _mm512_set1_ps((float) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -150,32 +148,32 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni( __m512i vacc8x89ABCDEF = _mm512_setzero_epi32(); size_t k = bl; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const 
__m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_loadu_si512(w); @@ -227,23 +225,23 @@ void xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), 
vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd-prfm.c index a12722ffe7eb..8680d2de9fac 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -143,35 +141,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) 
unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const 
__m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -216,25 +214,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd_prfm( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd.c index a392feafa75f..f2f54c2aaff8 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512skx-madd.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = 
_mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,35 +140,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), 
vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -214,25 +212,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = 
_mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni-prfm.c index 2e443ec8ca91..f477a6cdd664 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -143,35 +141,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const 
__m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -216,25 +214,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni.c index d1bfa4132d32..e6f9cf18b90d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnni.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) 
quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,35 +140,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -214,25 +212,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni-prfm.c index 52e593dc9653..9425fa6e4da6 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -145,35 +143,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), 
vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + 
const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -217,25 +215,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni_prfm( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni.c index ce673489c1ef..46eed490aae1 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c4-minmax-avx512vnnigfni.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) 
quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -144,35 +142,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const 
__m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -215,25 +213,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) 
unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd-prfm.c index 88d44581095d..0032770b5cea 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -153,35 +151,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * 
sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), 
vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -241,25 +239,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd.c index 9f06edacb183..6a4acc010f67 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512skx-madd.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) 
quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -152,35 +150,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -238,25 +236,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), 
vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni-prfm.c index 9272fa683011..ca71475db229 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = 
_mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -153,35 +151,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const 
__m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -241,25 +239,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); 
a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni.c index b8be029505cd..c57e21d3c66b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnni.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -152,35 +150,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i 
va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i 
va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -238,25 +236,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c index 53b55900a3ca..94e838d64b9d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); 
- const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -155,35 +153,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), 
vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -241,25 +239,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni.c index 1b3d72ea01ec..d89612094063 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x16c8-minmax-avx512vnnigfni.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) 
quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -154,35 +152,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), 
vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -238,25 +236,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c index 1fe78f01f70e..acbbb3bedc63 100644 --- 
a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -152,35 +150,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -239,25 +237,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c index f1922a0ffe07..b3c975b2ea97 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256skx-madd.c @@ -97,18 +97,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = 
_mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -151,35 +149,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i 
va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -237,25 +235,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c index bd135b1fb055..4493afd65263 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // 
XNN_FORCE_REALIZATION(voutput_min); @@ -152,35 +150,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -239,25 +237,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni.c index fa61cf53c228..b243b43c2378 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnni.c @@ -97,18 +97,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni( c9 
= c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -151,35 +149,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + 
const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -237,25 +235,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c index ff65b8709c9e..e43ccc1bc049 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = 
_mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -154,35 +152,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + 
const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -239,25 +237,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c index ed00b3cb888d..e4466ea4a924 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-10x8c8-minmax-avx256vnnigfni.c @@ -97,18 +97,16 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -153,35 +151,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -237,25 +235,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd-prfm.c index 2586b94c79fc..ff68dd6edfa9 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) 
quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -161,41 +159,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 
= _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) 
unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -246,29 +244,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd_prfm( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd.c index 4d826831b8a1..db28020ca28e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512skx-madd.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const 
__m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -160,41 +158,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i 
va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -244,29 +242,29 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni-prfm.c index 3181fcec00cd..ae9d301b144f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = 
_mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -161,41 +159,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -246,29 +244,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, 
vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni.c index 22f631820b0f..2cc0954d295f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnni.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point 
+ 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -160,41 +158,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i 
va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -244,29 +242,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = 
_mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni-prfm.c index ef641fdb4e4b..0fb2ea6fa308 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i 
vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -163,41 +161,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -247,29 +245,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni_prfm( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni.c index d6a5a16d55c2..f1b5e5e539b5 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c4-minmax-avx512vnnigfni.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i 
vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -162,41 +160,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -245,29 +243,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) 
unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd-prfm.c index 43c0f549d0aa..0adc720f9e1f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i 
vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -173,41 +171,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const 
__m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -275,29 +273,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd.c index 17928d6df8bb..8e67a8f3c9cd 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512skx-madd.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 
128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -172,41 +170,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 
+ 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i 
vbb01234567x01234567 = _mm512_load_si512(w); @@ -272,29 +270,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni-prfm.c index 7b0e0a7765de..20a25ece6085 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) 
quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -173,41 +171,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -275,29 +273,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni.c index 3785d660872b..53c0d8eb064c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnni.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni( c11 = c10; } - const __m512i vsign_mask = 
_mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -172,41 +170,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -272,29 +270,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c index cbd53f54dd3e..9cc12374d161 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -175,41 +173,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i 
va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -275,29 +273,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni.c index 3a39e3baaabd..e766a22bb5e6 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x16c8-minmax-avx512vnnigfni.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -174,41 +172,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni( size_t k = kc; while 
(k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 
8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -272,29 +270,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c index e9f511fdf6b3..9fa4d23ad5de 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + 
const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -172,41 +170,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), 
vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -273,29 +271,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c index fd556197caee..f89b54d89212 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256skx-madd.c @@ -109,20 +109,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) 
quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -171,41 +169,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 
8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -271,29 +269,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c index 099b94b5198d..2df056c8909f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = 
_mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -172,41 +170,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), 
vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -273,29 +271,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni.c index b27f0e8a3ac6..0ec92d6a6e08 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnni.c @@ -109,20 +109,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) 
quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -171,41 +169,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -271,29 +269,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c index 99569ff0fa9d..b0b5dd4f05af 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) 
quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -174,41 +172,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 
= _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = 
_mm256_load_si256(w); @@ -273,29 +271,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c index 4a0f9b013ad9..b3a2ba7654f2 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-12x8c8-minmax-avx256vnnigfni.c @@ -109,20 +109,18 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) 
quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -173,41 +171,41 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -271,29 +269,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd-prfm.c index 26aa27536f3c..b686df2b38cc 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd-prfm.c @@ -123,22 +123,20 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -179,47 +177,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = 
_mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 
4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -276,33 +274,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd_prfm( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); 
a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd.c index 7e641ae0a430..230d460f6d2e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512skx-madd.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const 
__m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -178,47 +176,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), 
vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -274,33 +272,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); 
+ const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni-prfm.c index 507a5de8fb5f..bb5d0068a022 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i 
vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -179,47 +177,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), 
vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i 
va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -276,33 +274,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i 
va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni.c index 59bbc98600ea..cb3156405497 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnni.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) 
quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -178,47 +176,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -274,33 +272,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni-prfm.c index 2f5f40d1b36e..c0e051b7a0e4 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 
128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -181,47 +179,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const 
__m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -277,33 +275,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni_prfm( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni.c index 799cd56d121c..980071282e1e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c4-minmax-avx512vnnigfni.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -180,47 +178,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = 
_mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -275,33 +273,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd-prfm.c index 93f306e72434..d17a36e23e58 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i 
vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -193,47 +191,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - 
const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const 
__m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -309,33 +307,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd.c index 1cad675b6d90..40aa60f7420f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512skx-madd.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -192,47 +190,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -306,33 +304,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - 
const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni-prfm.c index c3b084c62070..361f059b4de0 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const 
__m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -193,47 +191,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i 
va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + 
const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -309,33 +307,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni.c index 739109bbe2cf..e394ff97516e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = 
_mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -192,47 +190,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -306,33 +304,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i 
va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c index e88e44605a01..c1738eb8894a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i 
vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -195,47 +193,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i 
va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -309,33 +307,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni.c index 0d3810029251..a95051cb0d4e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnnigfni.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 
= _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -194,47 +192,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 
+ 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -306,33 +304,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c index 353e4ea8f0cc..66554209f36d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i 
vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -192,47 +190,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); 
+ const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i 
va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -307,33 +305,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i 
va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c index 26523deaaef5..f1abfd024e22 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256skx-madd.c @@ -121,22 +121,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 
= _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -191,47 +189,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i 
va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -305,33 +303,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c index 38081531791a..19503da37563 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -192,47 +190,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( size_t k = 
kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -307,33 +305,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni.c index 92afd7f471e6..ababb2b9026d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnni.c @@ -121,22 +121,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) 
quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -191,47 +189,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i 
va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -305,33 +303,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c index 74d3fc009ef6..a9e02f208ea2 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); 
+ const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -194,47 +192,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i 
va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -307,33 +305,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i 
va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c index 44d0ba3bfba7..4573d2dcfa02 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x8c8-minmax-avx256vnnigfni.c @@ -121,22 +121,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i 
vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -193,47 +191,47 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i 
va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const 
__m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -305,33 +303,33 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i 
va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx-prfm.c index a2843adba9a3..6eecdcfbac0b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -44,8 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ?
(kc & 63) : 64; @@ -64,19 +69,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -265,9 +270,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
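A note on the avx256vnnigfni hunks above: the removed vsign_mask XOR converted each signed int8 activation into an unsigned value (a + 128) so the dot product could run on unsigned inputs, and the extra +128 folded into every vinput_zero_point compensated for that bias, because sum((a + 128) * w) = sum(a * w) + 128 * sum(w) and the zero-point term is later multiplied by the per-column weight sums. With an accumulation path that consumes the signed activations directly, the XOR and the +128 drop out together. A minimal sketch of the identity, in plain C that is not part of XNNPACK (names are illustrative):

#include <stdint.h>

// Accumulating biased (unsigned) activations equals the signed dot product
// plus 128 * (sum of the weights); that is the term the old kernels folded
// into (zero_point + 128) * vksum and the new kernels no longer need.
static int32_t dot_with_sign_mask_trick(const int8_t* a, const int8_t* w, int k) {
  int32_t acc = 0;
  int32_t ksum = 0;
  for (int i = 0; i < k; i++) {
    const uint8_t au = (uint8_t) (a[i] ^ 0x80);  // same value as a[i] + 128
    acc += (int32_t) au * (int32_t) w[i];
    ksum += (int32_t) w[i];
  }
  return acc - 128 * ksum;  // recovers the plain signed dot product
}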
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[2].zero_point)); @@ -284,22 +299,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[13].zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx.c index 253b49d611c4..a1dab7e3aa2b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -43,8 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ?
(kc & 63) : 64; @@ -63,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -232,9 +237,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
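The __msan_unpoison blocks added to these avx512amx kernels work around a MemorySanitizer blind spot: MSan does not model the AMX _tile_stored instruction, so the bytes it writes into the res buffer still look uninitialized, and the following _mm512_load_epi32 reads would be flagged as use-of-uninitialized-value. Unpoisoning the buffer right after the tile store tells MSan those bytes are now defined. A minimal sketch of the guard, factored into a helper for illustration only (the helper name is hypothetical; the kernels inline the #if block):

#include <stddef.h>
#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    #include <sanitizer/msan_interface.h>
  #endif
#endif

// Call after _tile_stored() and before reading the buffer back with regular,
// MSan-instrumented loads.
static inline void example_unpoison_tile_buffer(void* buf, size_t size) {
#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
  __msan_unpoison(buf, size);  // mark the AMX-written bytes as initialized
  #endif
#endif
  (void) buf;   // no-op when MSan is not enabled
  (void) size;
}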
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[2].zero_point)); @@ -251,22 +266,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx( __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[13].zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx-prfm.c index d0e2daed4922..5cd829555022 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -44,9 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ?
(kc & 63) : 64; @@ -65,19 +69,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -335,10 +339,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
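On the res consolidation in these hunks: each AMX tile holds 16 rows of 64 bytes, i.e. 16x16 int32 lanes, so _tile_stored(t, &res[t][0], 64) writes tile t into one 16*16 slab of the new 2-D array, and row r of that tile begins 16 * r int32 values in. That is why the accumulator loads below step through offsets 0, 16, 32, ... and pick res[0] for output columns 0-15 and res[1] for columns 16-31. A small indexing sketch with illustrative names only:

#include <stddef.h>
#include <stdint.h>

// res[t] holds one stored AMX tile: 16 rows x 16 int32 lanes (64-byte rows).
// Row r of tile t is what _mm512_load_epi32(&res[t][0] + 16 * r) loads.
static inline int32_t* tile_row(int32_t res[][16 * 16], size_t t, size_t r) {
  return &res[t][0] + 16 * r;
}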
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); @@ -371,38 +385,39 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc13x0123456789ABCDEF = 
_mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx.c index 8859a233e625..0eebbfb3f771 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -43,9 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -286,10 +290,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); @@ -322,38 +336,39 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx( __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc13x0123456789ABCDEF = 
_mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx-prfm.c index 87bc98c79b38..ec325c7e20dc 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include @@ -44,11 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ?
(kc & 63) : 64; @@ -67,19 +69,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -475,12 +477,22 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -545,70 +557,71 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); vacc0x0123456789ABCDEF = 
_mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx.c index 06bcf00ccfc2..f128cbe21284 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-16x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include @@ -43,11 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -66,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -394,12 +396,22 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -464,70 +476,71 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); vacc0x0123456789ABCDEF = 
_mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512amx.c index e7e20fc3a97b..eb64a5eff3db 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include @@ -43,8 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[1][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -63,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -172,11 +177,22 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd-prfm.c index 019dd8861108..d6b9ed4de90f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -81,7 +79,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd_prfm( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd.c index 7c7ed19b686f..6d0fc03c9ed4 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512skx-madd.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -61,8 +59,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd( size_t k = kc; 
while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -79,7 +77,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c index 3bab5f60d5b0..c92809261ee1 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -81,7 +79,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c index c5b1e338982a..ee6d0b60ad12 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnni.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); 
+ const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -61,8 +59,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -79,7 +77,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni-prfm.c index 58a63181397f..5d865b0226d3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -82,7 +80,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni_prfm( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni.c index 20ed4483581d..e89d35c79ae8 100644 --- 
a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c4-minmax-avx512vnnigfni.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -80,7 +78,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd-prfm.c index e5fdb11d3c7f..d9673c0daa45 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -65,8 +63,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -90,7 +88,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); 
+ const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd.c index a457b1d0f4b4..17a156bff3df 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512skx-madd.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -87,7 +85,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni-prfm.c index 5fbcb809a01a..113531babd38 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -65,8 +63,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); 
a0 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -90,7 +88,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni.c index a0f559157c52..c561234af223 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -87,7 +85,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c index cf8a8261cea4..26fe85fecd15 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -67,8 +65,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const 
__m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -90,7 +88,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni.c index 6f261cecfb8a..8986508df774 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnnigfni.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -66,8 +64,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -87,7 +85,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x32c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x32c4-minmax-avx512amx.c index 8f6c2d664de9..68782a0f5808 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
#include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include @@ -43,9 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[2][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -226,14 +230,25 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time.
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c index 5146700b65f9..10c0527c54f0 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd_prfm( const int8_t* a0 = a; float* c0 = c; - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -62,8 +60,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd.c index 6fa983ecec10..9ecc05c4ef9c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-sse41-madd.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd( const int8_t* a0 = a; float* c0 = c; - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = 
_mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -61,8 +59,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -84,7 +82,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c index d90815a13dd9..c265eca5d73a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd_prfm( const int8_t* a0 = a; float* c0 = c; - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -66,8 +64,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -90,7 +88,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd.c index 66b17ab0492b..5b6cd0d5ff7e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x4c8-minmax-ssse3-madd.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd( const int8_t* a0 = a; float* c0 = c; - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m128i 
vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -65,8 +63,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -88,7 +86,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x64c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x64c4-minmax-avx512amx.c index 2c7839bec9ee..68a75729a685 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include @@ -43,11 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ?
(kc & 63) : 64; @@ -66,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -334,20 +336,31 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c index 11bcffac8e91..905e8734cc2d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd.c 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd.c index 4aa62bc7f8ec..e7ebb5a34eb3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx2-madd.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c index 367e8c5021f4..4d193f1824a3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c index 5766701aa5b6..edb97089eb93 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256skx-madd.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c index 10ea9a4eb1e2..05101d5999ec 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + 
const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni.c index c3aa9404a86b..3a44f262d2bf 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c index a165693fce17..b161beff9c63 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -66,8 +64,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c index 22fbf361afa8..1a3c71c40819 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avx256vnnigfni.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -65,8 +63,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c index 2626c6745eb6..e580ac13331e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = 
_mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -64,8 +62,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -88,7 +86,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni.c index f2c70074e844..a4c309e2faba 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x8c8-minmax-avxvnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c index 3f855ddf40d5..4150dde2cdfe 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd_prfm( c1 = c0; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i 
vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -74,11 +72,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -105,9 +103,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd.c index 00e9ed98858f..fc00a1ac137e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-sse41-madd.c @@ -49,10 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd( c1 = c0; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -73,11 +71,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = 
_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -103,9 +101,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c index f40f4ce05ccd..2a70c29f55ea 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd_prfm( c1 = c0; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -81,11 +79,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -112,9 +110,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i 
va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd.c index 2bb4d7d63376..bc4e296286ac 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x4c8-minmax-ssse3-madd.c @@ -49,10 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd( c1 = c0; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -80,11 +78,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -110,9 +108,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c index 5629765fae5a..fda6979bf5bb 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = 
_mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -76,11 +74,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -107,9 +105,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd.c index b86fc9133da0..a073d638b5e7 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avx2-madd.c @@ -49,10 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -75,11 +73,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -105,9 +103,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c index 55da6fa36c49..f858c67b193d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -76,11 +74,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -107,9 +105,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni.c index cc75a2be7945..d8709f969769 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-2x8c8-minmax-avxvnni.c @@ -49,10 +49,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -75,11 +73,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -105,9 +103,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c index 2decfb0cc31f..02e865b9e35a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd_prfm( c2 = c1; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - 
const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -80,14 +78,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -118,11 +116,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd.c index c296759b14ce..2650c61fa268 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-sse41-madd.c @@ -55,11 +55,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd( c2 = c1; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = 
_mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -79,14 +77,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -116,11 +114,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c index 1a9b62cde4b3..859d2deb5f70 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd_prfm( c2 = c1; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) 
quantization_params[2].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -90,14 +88,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -128,11 +126,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd.c index 4fc3dcab34cc..7ce95ac6837c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x4c8-minmax-ssse3-madd.c @@ -55,11 +55,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd( c2 = c1; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + 
const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -89,14 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -126,11 +124,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c index 56e8196d7955..d0e3413b67fa 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i 
vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -82,14 +80,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -120,11 +118,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd.c index a7da03298d33..a503cb056232 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avx2-madd.c @@ -55,11 +55,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = 
_mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -81,14 +79,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -118,11 +116,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c index 6817feae01ca..637090d2aca1 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) 
quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -82,14 +80,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -120,11 +118,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni.c index 04ed588cb7cf..332bc14a3a94 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-3x8c8-minmax-avxvnni.c @@ -55,11 +55,9 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const 
__m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -81,14 +79,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -118,11 +116,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd-prfm.c index 228c6ab412a8..f47217976dce 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd-prfm.c @@ -63,12 +63,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd_prfm( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) 
quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -89,17 +87,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -126,13 +124,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd_prfm( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd.c index 27401b480a00..90efc3b2db9c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512skx-madd.c @@ -62,12 +62,10 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -88,17 +86,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -124,13 +122,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i 
va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c index cbadb225567e..0134fc14bcf3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni-prfm.c @@ -63,12 +63,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -89,17 +87,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vbb0123456789ABCDEFx01234567 = 
_mm512_load_si512(w); @@ -126,13 +124,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c index e579ba3f0f68..0dc76e12c7d5 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnni.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -88,17 +86,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i 
va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -124,13 +122,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni-prfm.c index ff598e6e4509..eb6d1978f35c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni-prfm.c @@ -63,12 +63,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni_prfm( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -91,17 +89,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -127,13 +125,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni_prfm( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni.c index 6993eeeab34d..e801dfa0d626 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x16c4-minmax-avx512vnnigfni.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const 
__m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -90,17 +88,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -125,13 +123,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c index 534158ac6e12..300aecec5425 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd_prfm( c3 = c2; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); + const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -90,17 +88,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m128i va3x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m128i va3x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -135,13 +133,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = 
_mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd.c index 95e522c7d39b..b6f7249b6047 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-sse41-madd.c @@ -61,12 +61,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd( c3 = c2; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); + const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -89,17 +87,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m128i va3x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m128i va3x89ABCDEF = 
_mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -133,13 +131,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c index 0d5297973f8b..57e896d529d6 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd_prfm( c3 = c2; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); + const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -103,17 +101,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = 
_mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m128i va3x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m128i va3x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -148,13 +146,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd_prfm( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd.c index 4270f5fbdfce..cf0fdf5ff010 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x4c8-minmax-ssse3-madd.c @@ -61,12 +61,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd( c3 = c2; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m128i vinput_zero_point0 = _mm_set1_epi32((int) quantization_params[0].zero_point); + const __m128i vinput_zero_point1 = _mm_set1_epi32((int) quantization_params[1].zero_point); + const __m128i vinput_zero_point2 = _mm_set1_epi32((int) quantization_params[2].zero_point); + const __m128i vinput_zero_point3 = _mm_set1_epi32((int) quantization_params[3].zero_point); const __m128 voutput_min = _mm_set1_ps(params->scalar.min); const __m128 voutput_max = _mm_set1_ps(params->scalar.max); const __m128i vmask = _mm_set1_epi8(0x0F); @@ -102,17 +100,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m128i va0x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) 
unaligned_load_u64(a0 + 8)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m128i va0x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m128i va1x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m128i va1x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m128i va2x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m128i va2x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m128i va3x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m128i va3x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m128i vbb01234567x0123 = _mm_load_si128(w); @@ -146,13 +144,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd( } if (k != 0) { - const __m128i va0x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m128i va0x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m128i va1x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m128i va1x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m128i va2x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m128i va2x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m128i va3x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m128i va3x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m128i vbb01234567x0123 = _mm_load_si128(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c index d145ddb0fc6a..e27dead6ef45 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) 
quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -92,17 +90,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -137,13 +135,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd.c index c9a75fba3d5a..367739dea766 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avx2-madd.c @@ -61,12 +61,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -91,17 +89,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -135,13 +133,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c index 6d204509df6c..ebb185c79a92 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -92,17 +90,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const 
__m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -137,13 +135,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni.c index 065595a910b9..735ab2a0802a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-4x8c8-minmax-avxvnni.c @@ -61,12 +61,10 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -91,17 +89,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -135,13 +133,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd-prfm.c index 1f541ba7cb3b..894a150c3c67 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const 
__m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -98,20 +96,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -141,15 +139,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd_prfm( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i 
va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd.c index 1df3f84f2129..194d94ba0248 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512skx-madd.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -97,20 +95,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -139,15 +137,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c index ba2b0d0b0ad0..284144b2b5e3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -98,20 +96,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -141,15 +139,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c index e6ed23411636..19c480c7c0ea 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnni.c @@ -68,13 +68,11 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -97,20 +95,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -139,15 +137,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni( vacc4x0123456789ABCDEF = 
_mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni-prfm.c index 109b2c31844c..19b229c35df4 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -100,20 +98,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -142,15 +140,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni_prfm( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni.c index c911b0736b41..386cce47f8d9 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c4-minmax-avx512vnnigfni.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) 
quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -99,20 +97,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -140,15 +138,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd-prfm.c index 79c87896aff5..416edc8ff64c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -103,20 +101,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -156,15 +154,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd.c index 5f374062f150..4ed4478ee3d8 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512skx-madd.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = 
_mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -153,15 +151,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 
= _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni-prfm.c index 8bf579a7e363..a651a3c6213b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -103,20 +101,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - 
const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -156,15 +154,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni.c index 0e1b62724155..3922200d9a4f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnni.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = 
_mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -153,15 +151,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i 
vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c index 8cb77ba76453..84444e8c3413 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -105,20 +103,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -156,15 +154,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni.c index f513b017232a..03c283da0dcd 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x16c8-minmax-avx512vnnigfni.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -104,20 +102,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * 
sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -153,15 +151,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c index 739f42ad953c..a96bb97a3571 100644 --- 
a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + 
const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd.c index bfd5ec705f42..048748b14a4a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx2-madd.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c index 2150f3b3ca20..266999f758eb 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd-prfm.c @@ -68,13 +68,11 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); 
a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c index 253fb818e712..973b13ec3161 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256skx-madd.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - 
const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c index 049f211b2307..b84eb64ff923 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) 
quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni.c index d1ff09133ed1..9bf6ffaf48cc 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + 
const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c index 7b7e2214e902..16fc897a7d7f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) 
quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -104,20 +102,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c index 22a5166006a4..a3fd5f1f96ca 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avx256vnnigfni.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -103,20 +101,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - 
const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c index 897fc9850eab..3bf76b2d6e6b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) 
quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -102,20 +100,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -154,15 +152,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni.c index 652d5c7b752c..7facb2c33845 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-5x8c8-minmax-avxvnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i 
va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -152,15 +150,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c index 2aa0bbe4f38e..491587c62145 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd-prfm.c @@ -74,14 +74,12 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i 
vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -112,23 +110,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -171,17 +169,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), 
vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd.c index 2e846ee14d47..429447f48685 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avx2-madd.c @@ -73,14 +73,12 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -111,23 +109,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + 
const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -169,17 +167,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c index 86779896b225..12466395966c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni-prfm.c @@ -74,14 +74,12 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -112,23 +110,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i 
va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -171,17 +169,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni.c index 76f2d2945752..b40aed705963 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-6x8c8-minmax-avxvnni.c @@ -73,14 +73,12 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i 
vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -111,23 +109,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -169,17 +167,17 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512amx.c index 976e4cedd291..2b0de86073a8 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,8 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[1][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -63,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -196,9 +201,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[2].zero_point)); @@ -206,13 +221,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512amx( __m512i vacc4x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[4].zero_point)); __m512i vacc5x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[5].zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc1x0123456789ABCDEF = _mm512_srai_epi32(vacc1x0123456789ABCDEF, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd-prfm.c index e4632fa7b81f..193ffc73a96d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) 
quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -116,26 +114,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -171,19 +169,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd_prfm( vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd.c index 4175029981a1..b1e877642baf 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512skx-madd.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) 
quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -115,26 +113,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -169,19 +167,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd( 
vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c index 04b14c0beae9..13664ead8118 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = 
_mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -116,26 +114,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -171,19 +169,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c index 60567cce10d7..d9125eb174fc 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnni.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -115,26 +113,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = 
_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -169,19 +167,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni( vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 
+= 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni-prfm.c index 12e8768c052e..8115e877e493 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -118,26 +116,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -172,19 +170,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni_prfm( vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git 
a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni.c index 03b539290336..a972ea713950 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c4-minmax-avx512vnnigfni.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -117,26 +115,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -170,19 +168,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni( vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd-prfm.c index f4b4705f20a4..fbd4738d0a9d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd_prfm( c6 = c5; } - const __m512i 
vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -123,26 +121,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - 
const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -190,19 +188,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd.c index f483d0ed6bf5..ca448a5e193f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512skx-madd.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point 
+ 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -187,19 +185,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni-prfm.c index 688c61a58e90..862d2256e3ba 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = 
_mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -123,26 +121,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), 
vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -190,19 +188,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni.c index ae3d37354dc8..946511a42570 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnni.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i 
vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), 
vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -187,19 +185,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c index 3dd2248c6f0c..585a17e1f4ae 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) 
quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -125,26 +123,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i 
va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -190,19 +188,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni.c index 71566bb83a50..74d8b928e9df 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x16c8-minmax-avx512vnnigfni.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = 
_mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -124,26 +122,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -187,19 +185,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x32c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x32c4-minmax-avx512amx.c index 649754a65d5a..93217cbe01f6 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,9 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[2][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -64,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -250,10 +254,20 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
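[Editorial note, not part of the patch] The MemorySanitizer guard added above is needed because _tile_stored() writes the accumulator tiles through AMX instructions that MSan does not instrument, so the result buffer would otherwise be reported as uninitialized when it is read back with _mm512_load_epi32(). The standalone C sketch below shows the same unpoison pattern in isolation; the function name store_and_unpoison, the memcpy stand-in for the tile store, and the buffer shape are hypothetical illustrations, and the only real API relied on is __msan_unpoison from sanitizer/msan_interface.h.

/* Standalone sketch: marking a buffer filled by non-instrumented code
   (here, stand-in for AMX _tile_stored()) as initialized under MSan. */
#include <stdint.h>
#include <string.h>
#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    #include <sanitizer/msan_interface.h>
  #endif
#endif

static void store_and_unpoison(int32_t res[2][7 * 16], const int32_t* src) {
  /* Stand-in for the AMX tile stores; in the kernel this write is invisible
     to MSan because it happens inside the _tile_stored() instruction. */
  memcpy(res, src, sizeof(int32_t) * 2 * 7 * 16);

#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    /* Tell MSan the whole result buffer is initialized so the subsequent
       vector loads of res[i][j] are not flagged as uninitialized reads. */
    __msan_unpoison(res, sizeof(int32_t) * 2 * 7 * 16);
  #endif
#endif
}

[End of editorial note; the patch continues below.]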
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); @@ -268,20 +282,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x32c4__avx512amx( __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[5].zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] 
+ 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x64c4-minmax-avx512amx.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x64c4-minmax-avx512amx.c index 6263155c8c0a..c42e3973f49b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,11 +48,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -66,19 +68,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -358,12 +360,22 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -392,34 +404,35 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x64c4__avx512amx( __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); vacc0x0123456789ABCDEF = _mm512_srai_epi32(vacc0x0123456789ABCDEF, 4); vacc0xGHIJKLMNOPQRSTUV = _mm512_srai_epi32(vacc0xGHIJKLMNOPQRSTUV, 4); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c index d51f82bc3126..79d84a856721 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i 
va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd.c index 4a697cdc8c35..dc6d887520ca 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx2-madd.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const 
__m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c index 3441bcf936c9..3ddb172ef068 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i 
vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c index 5b5728bb9cac..486d1cba688b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256skx-madd.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i 
vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c index 3df4da4f9d06..95965fbf7edf 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const 
__m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni.c index a5e880bf131a..34b1d6520503 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) 
quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), 
vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c index a1134da8baa0..819aa032534d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) 
quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -124,26 +122,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 
+= 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c index 7a21057d748c..7f9cdcbf5320 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avx256vnnigfni.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + 
const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -123,26 +121,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c index e88497c1703f..4a915fec7edb 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = 
_mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,26 +120,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -188,19 +186,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni.c index 763fd1417a4b..c1f5c723bbac 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-7x8c8-minmax-avxvnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -186,19 +184,19 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd-prfm.c index ef580589b7ee..d2d5da1dcf48 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -125,29 +123,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -186,21 +184,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd_prfm( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd.c index f073c96442e3..3bbe73e5caf5 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512skx-madd.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const 
__m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -124,29 +122,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -184,21 +182,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const 
__m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c index 854f005b702f..d76acb08af99 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const 
__m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -125,29 +123,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i 
vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -186,21 +184,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c index 0d87b05c18b9..cfc32da455cd 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnni.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) 
quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -124,29 +122,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = 
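/*
 * Descriptive note on the va*x0123 / va*x4567 broadcasts in these 8x16c4 kernels:
 * each _mm512_set1_epi32((int) unaligned_load_u32(aN)) copies the same four int8
 * activation bytes of row N into every 32-bit lane of a ZMM register. Conceptually,
 * each of the 16 output channels then multiplies those four bytes element-wise
 * against its own four packed weights for this k-group and sums the products into a
 * 32-bit accumulator; the multiply-accumulate instructions themselves are outside
 * the hunks shown here. A rough scalar sketch with invented names, for one channel:
 *
 *   int32_t acc = 0;
 *   for (size_t kk = 0; kk < 4; kk++) {
 *     acc += (int32_t) a_row[kk] * (int32_t) w_channel[kk];
 *   }
 */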
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -184,21 +182,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni-prfm.c index 7fff7c4116fa..bd78bd11be98 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = 
_mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -127,29 +125,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + 
const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -187,21 +185,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni_prfm( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni.c index 919a30947a49..72028ba9dfc3 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c4-minmax-avx512vnnigfni.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) 
quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -126,29 +124,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = 
_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -185,21 +183,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd-prfm.c index 8c87c8c19d5d..cc192d5206ee 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 
128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,29 +131,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) 
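/*
 * Descriptive note on the 8x16c8 variants around here: activations are broadcast
 * eight bytes at a time with _mm512_set1_epi64((int64_t) unaligned_load_u64(aN)),
 * and the main loop consumes 16 bytes of K per iteration (k >= 16 * sizeof(int8_t)),
 * compared with 4-byte broadcasts and 8 bytes per iteration in the 8x16c4 kernels
 * earlier. Conceptually, per output channel and k-group (rough sketch, invented names):
 *
 *   int32_t acc = 0;
 *   for (size_t kk = 0; kk < 8; kk++) {
 *     acc += (int32_t) a_row[kk] * (int32_t) w_channel[kk];
 *   }
 *
 * The byte-wise dot-product instructions themselves are outside these hunks.
 */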
unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -207,21 +205,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd.c index 2d51c0b66b8e..2f369fe520f5 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd.c +++ 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512skx-madd.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -204,21 +202,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i 
va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni-prfm.c index ce8849a58208..3cae6efe19ab 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,29 +131,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const 
__m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -207,21 +205,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni.c index 4e8ac5c6c427..a954af21f407 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -204,21 +202,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c index 0c0000c880aa..37618d763d82 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 
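/*
 * Descriptive note on the constants set up just below: each vinput_zero_pointN holds
 * the per-row dynamic-quantization zero point taken from quantization_params[N], and
 * voutput_min / voutput_max broadcast the clamp bounds from params->scalar. The final
 * float outputs are clamped to that range; roughly, per element (illustrative only):
 *
 *   float y = acc_scaled;   // accumulator after conversion and scaling to float
 *   y = y < output_min ? output_min : y;
 *   y = y > output_max ? output_max : y;
 */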
voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -135,29 +133,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -207,21 +205,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni.c index d095bc452c23..73c3b682d603 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnnigfni.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i 
vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -134,29 +132,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -204,21 +202,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c index cccac63e55c9..b4a749e16f77 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) 
quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd.c index 6193f00892cd..a3091cd29d00 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx2-madd.c @@ -85,16 +85,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd( c7 = c6; } - 
const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i 
va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c index 535228e2735c..f2eae95717b2 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + 
const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c index 5207522ac504..820e94801d3b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256skx-madd.c @@ -85,16 +85,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + 
const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c index cf529ec1e5ff..e03b051de7e7 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min 
= _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni.c index c462b2c7ac63..f26ef02c57b9 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = 
_mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c index 2c61a5af677c..df725bfa8e6e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 
= _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -134,29 +132,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - 
const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c index 717e199c7052..b0c78b56821b 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avx256vnnigfni.c @@ -85,16 +85,14 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,29 +131,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + 
const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c index 965fe712b134..930a64746ddf 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,29 +130,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -205,21 +203,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni.c index e6f19eeadd49..408f85fdd3e4 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x8c8-minmax-avxvnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const 
__m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -203,21 +201,21 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd-prfm.c index e8fa3c74321d..61825e9374db 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) 
quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -134,32 +132,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) 
unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -201,23 +199,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd_prfm( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd.c index fdd6212da165..4e236225b25c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512skx-madd.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i 
vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,32 +131,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -199,23 +197,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git 
a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni-prfm.c index 6746dbe65a42..dc7dbf164dd5 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -134,32 +132,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -201,23 +199,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 
+= 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni.c index 44eea33bfeea..9bfa27b03e5a 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnni.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,32 +131,32 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i 
vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -199,23 +197,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni-prfm.c index 115b6a4dd744..c6b71c50bc74 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i 
vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -136,32 +134,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), 
vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -202,23 +200,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni_prfm( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni.c index 04a318eb95e5..82a36260165c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c4-minmax-avx512vnnigfni.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni( c8 = 
c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -135,32 +133,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), 
vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vbb0123456789ABCDEFx01234567 = _mm512_load_si512(w); @@ -200,23 +198,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vbb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd-prfm.c index 52d06187248d..5f6af38c729c 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -143,32 +141,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -224,23 +222,23 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd.c index bef5a85f5cfc..0635b7860c4f 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512skx-madd.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) 
quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,32 +140,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -221,23 +219,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni-prfm.c 
b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni-prfm.c index acd83b39e71f..1d9493c73eec 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -143,32 +141,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -224,23 +222,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i 
va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni.c index 25704443b435..5f2fd87ad740 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnni.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i 
vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,32 +140,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -221,23 +219,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c index 0297a517c4eb..4161045f039e 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i 
vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -145,32 +143,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - 
const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -224,23 +222,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni.c index 4cac21b194ff..ad1531827fc0 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x16c8-minmax-avx512vnnigfni.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -144,32 +142,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); @@ -221,23 +219,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni( } if (k != 0) { - const __m512i va0x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vbb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c index ccf656d89363..147986ba79d0 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) 
quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,32 +140,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i 
va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -222,23 +220,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c index 01eda66525c2..e2104c6218c4 100644 --- 
a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256skx-madd.c @@ -91,17 +91,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,32 +139,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -220,23 +218,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), 
vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c index 35ad8b46d1b5..e80ee482a2c2 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) 
quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -142,32 +140,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -222,23 +220,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni.c index 58deb7039e4e..b584d0058126 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnni.c @@ -91,17 +91,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = 
_mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,32 +139,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i 
va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -220,23 +218,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c index f8ca6b89b0d9..0c44ef756c3d 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -144,32 +142,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -222,23 +220,23 @@ void 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c index 26e08b4381b3..fc71a01032d1 100644 --- a/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c +++ b/src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-9x8c8-minmax-avx256vnnigfni.c @@ -91,17 +91,15 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = 
_mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -143,32 +141,32 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const 
__m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); @@ -220,23 +218,23 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vbb01234567x01234567 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..b0225fc5f4f1 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,375 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1104 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1192] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], 
zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + # Load quantization_params pointer from stack + mov r11, [rsp + 1192] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + # Min/max clamping.. 
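Between the integer loop and the clamping instructions that follow, the epilogue converts each int32 accumulator to float and applies the dequantization affine transform: a per-row input scale (read from quantization_params with a {1to16} broadcast), then one FMA that applies what appears to be the per-column weight scale and adds the per-column bias from the packed weights, and finally a clamp to [output_min, output_max]. A scalar sketch of that per-element math, illustrative only and assuming the two vectors loaded from the weights are indeed scale then bias:

#include <math.h>
#include <stdint.h>

// Illustrative only: the per-element epilogue applied to one int32 accumulator.
// input_scale is per-row (from quantization_params); weight_scale and bias are
// assumed to be the per-column values read from the packed weights.
static float epilogue(int32_t acc, float input_scale, float weight_scale,
                      float bias, float output_min, float output_max) {
  float x = (float) acc;            // vcvtdq2ps
  x *= input_scale;                 // vmulps ... {1to16}
  x = x * weight_scale + bias;      // vfmadd213ps
  x = fminf(x, output_max);         // vminps
  x = fmaxf(x, output_min);         // vmaxps
  return x;
}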
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + vmovups [rbp], zmm20 + vmovups [r8], zmm21 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + +return: + add rsp, 1104 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni-prfm.c index 62508d9437e8..6573a97d929b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,35 +139,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 
= _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -213,25 +211,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) 
{ - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni.c index b83a2e4f87e8..9587a8b56884 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i 
vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -140,35 +138,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i 
va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -210,25 +208,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni( vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, vacc1x9x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c index 0790c917afb4..762e3828a3ea 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c @@ -99,18 +99,16 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -151,35 +149,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + 
const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -237,25 +235,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni.c index 3aa696ff5e42..27de18b9e31b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const 
__m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -150,35 +148,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), 
vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -232,25 +230,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..f0cd38a4acca --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,471 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1104 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
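The cmovle blocks that follow implement the same row clamping that the C kernels in this diff express with `if XNN_UNPREDICTABLE(mr <= n)`: when fewer than MR rows are requested, the unused A and C row pointers alias the last valid row so the kernel can load and store unconditionally. A minimal scalar sketch of that idea (illustrative only, MR fixed at 4 for brevity, not part of the generated kernel):

#include <stddef.h>
#include <stdint.h>

// Illustrative only: derive per-row A/C pointers for a kernel with MR = 4.
// Rows beyond `mr` alias the previous row, so loads/stores stay in bounds and
// the redundant results land on rows that are written anyway.
static void clamp_row_pointers(size_t mr,
                               const int8_t* a[4], size_t a_stride,
                               float* c[4], size_t cm_stride) {
  for (size_t i = 1; i < 4; i++) {
    a[i] = (const int8_t*) ((uintptr_t) a[i - 1] + a_stride);
    c[i] = (float*) ((uintptr_t) c[i - 1] + cm_stride);
    if (mr <= i) {  // same condition the asm tests with `cmp rdi, i` / `cmovle`
      a[i] = a[i - 1];
      c[i] = c[i - 1];
    }
  }
}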
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1192] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + + # Initialize accumulators with k_sum * input zero point. 
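The initialization that follows folds the input zero point into the accumulators using the identity sum_k (a_k - zp) * b_k = sum_k a_k * b_k - zp * sum_k b_k: the per-column weight sums are precomputed at packing time (with whatever sign convention the packing code uses), so the vpdpbusd inner loop only has to accumulate the raw dot products. The same identity is why the intrinsic kernels in this diff can drop the 0x80 XOR and the matching "+ 128" zero-point bias together: (a + 128) - (zp + 128) = a - zp. A scalar sketch of the fold-in; names and the sign of the folded term are illustrative, not XNNPACK APIs:

#include <stdint.h>
#include <stdio.h>

// Illustrative only: compute sum_k (a[k] - zp) * b[k] three ways and check that
// (1) folding the zero-point term into the accumulator initialization and
// (2) offsetting both a and zp by 128 (the removed XOR-0x80 trick)
// give the same result as the direct computation.
int main(void) {
  const int8_t a[4] = {-3, 7, 0, -128};
  const int8_t b[4] = {5, -2, 9, 1};
  const int32_t zp = 4;

  int32_t direct = 0;
  for (int k = 0; k < 4; k++) direct += ((int32_t) a[k] - zp) * b[k];

  // Folded: accumulator starts at -zp * ksum, loop adds plain a*b products.
  int32_t ksum = 0;
  for (int k = 0; k < 4; k++) ksum += b[k];
  int32_t folded = -zp * ksum;
  for (int k = 0; k < 4; k++) folded += (int32_t) a[k] * b[k];

  // Old formulation: a ^ 0x80 equals a + 128 as unsigned, paired with zp + 128.
  int32_t offset = -(zp + 128) * ksum;
  for (int k = 0; k < 4; k++) offset += ((int32_t) (uint8_t) (a[k] ^ 0x80)) * b[k];

  printf("%d %d %d\n", direct, folded, offset);  // all three agree
  return 0;
}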
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm28, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm29, zmm7, ZMMWORD PTR [rsp + 912] + vpmulld zmm30, zmm7, ZMMWORD PTR [rsp + 976] + vpmulld zmm4, zmm7, ZMMWORD PTR [rsp + 1040] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm28, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm29, zmm2, zmm7 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpdpbusd zmm30, zmm2, zmm7 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + vpdpbusd zmm4, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + vcvtdq2ps zmm30, zmm30 + vcvtdq2ps zmm4, zmm4 + # Load quantization_params pointer from stack + mov r11, [rsp + 1192] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 4]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 12]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 20]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 28]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 36]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 44]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 52]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 60]{1to16} + vmulps zmm30, zmm30, DWORD PTR [r11 + 68]{1to16} + vmulps zmm4, zmm4, DWORD PTR [r11 + 76]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + vfmadd213ps zmm28, zmm11, zmm7 + vfmadd213ps zmm29, zmm11, zmm7 + vfmadd213ps zmm30, zmm11, zmm7 + vfmadd213ps zmm4, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm4, zmm1, zmm4 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm4, zmm0, zmm4 + + # Pop output pointers from the stack. 
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm22 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm23 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm24 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm25 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm26 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm27 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm28 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm29 + vmovups [rbp], zmm20 + vmovups [rbp + 64], zmm30 + vmovups [r8], zmm21 + vmovups [r8 + 64], zmm4 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm30 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm4 + +return: + add rsp, 1104 + + # Restore the callee saved registers. 
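The tail: block just above handles a partial set of output columns with masked stores: shifting -1 left by nc and inverting yields a word with the nc low bits set, the low 16 bits go into k1 for the first 16 floats, and in this 32-wide kernel the mask is shifted right by 16 to obtain k2 for the second 16. A small sketch of that mask arithmetic, illustrative only:

#include <stdint.h>
#include <stdio.h>

// Illustrative only: build the store masks for a partial tile of `nc` output
// columns (0 < nc < 32), mirroring the sal/not/kmovw/shr sequence in tail:.
int main(void) {
  for (uint32_t nc = 1; nc < 32; nc++) {
    uint32_t mask = ~(~0u << nc);           // nc low bits set, i.e. (1u << nc) - 1
    uint16_t k1 = (uint16_t) mask;          // columns 0..15
    uint16_t k2 = (uint16_t) (mask >> 16);  // columns 16..31
    printf("nc=%2u k1=0x%04x k2=0x%04x\n", nc, k1, k2);
  }
  return 0;
}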
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c index adf226a86115..de64588df83f 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni-prfm.c @@ -98,18 +98,16 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -150,35 +148,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -234,25 +232,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni.c index 78d4caa82097..13b6f7e2bf21 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x8c8-minmax-avx256vnni.c @@ -97,18 +97,16 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = 
_mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -149,35 +147,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i 
va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -231,25 +229,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..638c6dc56714 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,404 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1168 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
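+  # Note on the layout below: the slots at [rsp - 128] through [rsp - 296] hold the
+  # per-row a and c pointers. Each cmp/cmovle pair clamps the next row's pointers
+  # back to the previous row whenever mr (in rdi) is smaller than that row index,
+  # so rows past mr recompute and store the last valid row instead of reading or
+  # writing out of bounds.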
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1256] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + mov edi, [r11 + 80] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1104], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with k_sum * input zero point. 
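+  # The 64 bytes at [r9] are the packed per-column k_sum values; multiplying them by
+  # the per-row input zero points spilled to [rsp + 464..1104] seeds each accumulator
+  # with the zero-point term, roughly (pseudo-C, descriptive names only):
+  #   acc[m][n] = input_zero_point[m] * ksum[n]
+  # so the vpdpbusd loop that follows only adds the raw activation*weight products.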
+ vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + vpmulld zmm22, zmm6, ZMMWORD PTR [rsp + 1104] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + vpbroadcastd zmm2, [rdi + r11] + vpdpbusd zmm22, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + # Load quantization_params pointer from stack + mov r11, [rsp + 1256] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 84]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + vfmadd213ps zmm22, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + + # Pop output pointers from the stack. 
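+  # "Pop" is figurative: these are plain loads from the spill slots written in the
+  # prologue. The c pointers must be reloaded here because rsi/rax/r15/r14/... carried
+  # the a pointers for the duration of inner_loop.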
+ mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + vmovups [rbp], zmm20 + vmovups [r8], zmm21 + vmovups [rdi], zmm22 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + add rdi, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + vmovups ZMMWORD PTR [rdi]{k1}, zmm22 + +return: + add rsp, 1168 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..aad8d973fe7a --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,509 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1168 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 272], rax + mov [rsp - 280], r13 + + # Clamp a & c pointers if mr <= 10 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 288], rsi + mov [rsp - 296], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1256] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + mov edi, [r11 + 72] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1040], zmm6 + mov edi, [r11 + 80] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 1104], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + mov r8, [rsp - 272] + mov rdi, [rsp - 288] + + # Initialize accumulators with k_sum * input zero point. 
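+  # Same seeding as the 16-column kernel above, but the 32-column tile needs two
+  # k_sum vectors ([r9] and [r9 + 64]) and two accumulators per row. With 11 rows
+  # that is 22 live zmm accumulators, which is why zmm4, zmm8 and zmm9 also serve
+  # as accumulators for the high halves of rows 8-10.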
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm6, ZMMWORD PTR [rsp + 1040] + vpmulld zmm22, zmm6, ZMMWORD PTR [rsp + 1104] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm28, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm29, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm30, zmm7, ZMMWORD PTR [rsp + 912] + vpmulld zmm4, zmm7, ZMMWORD PTR [rsp + 976] + vpmulld zmm8, zmm7, ZMMWORD PTR [rsp + 1040] + vpmulld zmm9, zmm7, ZMMWORD PTR [rsp + 1104] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm28, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm29, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm30, zmm2, zmm7 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpdpbusd zmm4, zmm2, zmm7 + vpbroadcastd zmm2, [r8 + r11] + vpdpbusd zmm21, zmm2, zmm6 + vpdpbusd zmm8, zmm2, zmm7 + vpbroadcastd zmm2, [rdi + r11] + vpdpbusd zmm22, zmm2, zmm6 + vpdpbusd zmm9, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
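+  # Dequantization: each int32 accumulator is converted to float, scaled by the
+  # per-row factor loaded 4 bytes past that row's zero point in quantization_params
+  # (the [r11 + 4 + 8*m]{1to16} broadcasts), then combined with the per-channel
+  # weight scale and bias read from the packed weights via vfmadd213ps, roughly
+  #   out = (float) acc * row_scale * channel_scale + bias
+  # (row_scale and channel_scale are descriptive names, not identifiers from the source).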
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + vcvtdq2ps zmm30, zmm30 + vcvtdq2ps zmm4, zmm4 + vcvtdq2ps zmm8, zmm8 + vcvtdq2ps zmm9, zmm9 + # Load quantization_params pointer from stack + mov r11, [rsp + 1256] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 76]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 84]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 4]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 12]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 20]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 28]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 36]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 44]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 52]{1to16} + vmulps zmm30, zmm30, DWORD PTR [r11 + 60]{1to16} + vmulps zmm4, zmm4, DWORD PTR [r11 + 68]{1to16} + vmulps zmm8, zmm8, DWORD PTR [r11 + 76]{1to16} + vmulps zmm9, zmm9, DWORD PTR [r11 + 84]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm10, zmm6 + vfmadd213ps zmm22, zmm10, zmm6 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + vfmadd213ps zmm28, zmm11, zmm7 + vfmadd213ps zmm29, zmm11, zmm7 + vfmadd213ps zmm30, zmm11, zmm7 + vfmadd213ps zmm4, zmm11, zmm7 + vfmadd213ps zmm8, zmm11, zmm7 + vfmadd213ps zmm9, zmm11, zmm7 + # Min/max clamping.. 
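+  # zmm0 and zmm1 were broadcast from the params struct in the prologue (offsets 0
+  # and 4, which the C kernels read as scalar.min and scalar.max). The accumulators
+  # are clamped to [min, max]: vminps against the upper bound first, then vmaxps
+  # against the lower bound.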
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm4, zmm1, zmm4 + vminps zmm8, zmm1, zmm8 + vminps zmm9, zmm1, zmm9 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm4, zmm0, zmm4 + vmaxps zmm8, zmm0, zmm8 + vmaxps zmm9, zmm0, zmm9 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + mov r8, [rsp - 280] + mov rdi, [rsp - 296] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm23 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm24 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm25 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm26 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm27 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm28 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm29 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm30 + vmovups [rbp], zmm20 + vmovups [rbp + 64], zmm4 + vmovups [r8], zmm21 + vmovups [r8 + 64], zmm8 + vmovups [rdi], zmm22 + vmovups [rdi + 64], zmm9 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + add rdi, 128 + + # Write output pointers to the stack. 
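+  # The c pointers, advanced by 128 bytes (32 floats) in the full-store path above,
+  # are written back to their spill slots below for the next outer_loop iteration
+  # over nc. When fewer than 32 columns remain, the tail path instead builds write
+  # masks from the remaining count: ~(-1 << nc) == (1 << nc) - 1, with the low
+  # 16 bits going to k1 and the next 16 bits to k2 for the two halves of each row.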
+ mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + mov [rsp - 280], r8 + mov [rsp - 296], rdi + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm29 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm30 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm4 + vmovups ZMMWORD PTR [r8]{k1}, zmm21 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm8 + vmovups ZMMWORD PTR [rdi]{k1}, zmm22 + vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm9 + +return: + add rsp, 1168 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni-prfm.c index 17813e5b0504..fa00de409bb7 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + 
const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -159,41 +157,41 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); 
a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -243,29 +241,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 
= _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni.c index 974b77276246..548bdb7c47a7 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + 
const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -158,41 +156,41 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -240,29 +238,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni( vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, vacc1x11x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i 
va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni-prfm.c index b4f2ae8da0ac..b4b343f9a84a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni-prfm.c @@ -111,20 +111,18 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) 
quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -171,41 +169,41 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF 
= _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -271,29 +269,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 
8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni.c index 5cc6b537eb90..5e2a7df15a88 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c8-minmax-avx512vnni.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = 
_mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -170,41 +168,41 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i 
va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -266,29 +264,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), 
vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c index bb6fe86168c7..0a803f73d4b4 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni-prfm.c @@ -110,20 +110,18 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i 
vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -170,41 +168,41 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + 
const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -268,29 +266,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni.c index 9e821f37f98a..84058523993a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x8c8-minmax-avx256vnni.c @@ -109,20 +109,18 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) 
quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -169,41 +167,41 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -265,29 +263,29 @@ void 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni-prfm.c index 9f84c4c4d390..3376e11ad8a7 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i 
vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -177,47 +175,47 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 
8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - 
const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -273,33 +271,33 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i 
va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni.c index 2275c3f400df..45778478edc3 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c4-minmax-avx512vnni.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) 
quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -176,47 +174,47 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = 
_mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -270,33 +268,33 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni( vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, vacc1x13x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni-prfm.c index bf50f9ce68bb..8d72abb85074 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni-prfm.c @@ -123,22 +123,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = 
_mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -191,47 +189,47 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i 
va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -305,33 +303,33 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni.c index be9d97f00efe..32fbb6f5218e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x16c8-minmax-avx512vnni.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); + const __m512i vinput_zero_point9 = _mm512_set1_epi32((int) quantization_params[9].zero_point); + const __m512i vinput_zero_point10 = _mm512_set1_epi32((int) quantization_params[10].zero_point); + const __m512i vinput_zero_point11 = _mm512_set1_epi32((int) 
quantization_params[11].zero_point); + const __m512i vinput_zero_point12 = _mm512_set1_epi32((int) quantization_params[12].zero_point); + const __m512i vinput_zero_point13 = _mm512_set1_epi32((int) quantization_params[13].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -190,47 +188,47 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const 
__m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -300,33 +298,33 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c index 1bff2a8dd929..54a0ecded59e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni-prfm.c @@ -122,22 +122,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = 
_mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -190,47 +188,47 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -302,33 +300,33 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni.c index 5a502528c777..3184e567fedb 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-14x8c8-minmax-avx256vnni.c @@ -121,22 +121,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); - const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point + 128); - const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point + 128); - const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point + 128); - const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point + 128); - const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const 
__m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); + const __m256i vinput_zero_point9 = _mm256_set1_epi32((int) quantization_params[9].zero_point); + const __m256i vinput_zero_point10 = _mm256_set1_epi32((int) quantization_params[10].zero_point); + const __m256i vinput_zero_point11 = _mm256_set1_epi32((int) quantization_params[11].zero_point); + const __m256i vinput_zero_point12 = _mm256_set1_epi32((int) quantization_params[12].zero_point); + const __m256i vinput_zero_point13 = _mm256_set1_epi32((int) quantization_params[13].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -189,47 +187,47 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -299,33 +297,33 @@ void 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx-prfm.c index 012594b46295..1620e4557ab1 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
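Note on the VNNI hunks above: these kernels previously flipped the sign bit of every activation byte (the `vsign_mask` XOR with 0x80) and compensated by adding 128 to each per-row input zero point; the diff removes both halves of that adjustment together. They can only be removed as a pair, because XOR-ing an int8 value with 0x80 equals adding 128 and reinterpreting the byte as uint8. The scalar sketch below is only an illustration of that identity, not kernel code; the subtraction-style zero-point correction and the sample values are assumptions made for clarity, and whatever offset handling the signed inputs still need is presumably done outside the hunks shown here.

// Illustration only: the XOR-with-0x80 scheme plus a (zero_point + 128)
// correction gives the same accumulator as signed bytes plus a plain
// zero_point correction, so the two adjustments always change in lockstep.
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
  const int8_t a[4] = {-128, -1, 0, 127};  // sample activations (assumed)
  const int8_t w[4] = {3, -7, 11, -13};    // sample weights (assumed)
  const int32_t zero_point = 5;            // sample activation zero point

  int32_t ksum = 0;         // sum of weights along k
  int32_t acc_signed = 0;   // dot product with signed activation bytes
  int32_t acc_flipped = 0;  // dot product with sign-bit-flipped (uint8) bytes
  for (int k = 0; k < 4; k++) {
    ksum += w[k];
    acc_signed += (int32_t) a[k] * (int32_t) w[k];
    // Old scheme: (uint8_t)(a ^ 0x80) is exactly a + 128.
    acc_flipped += (int32_t) (uint8_t) (a[k] ^ 0x80) * (int32_t) w[k];
  }
  // Old scheme corrects with (zero_point + 128), new scheme with zero_point.
  const int32_t old_result = acc_flipped - (zero_point + 128) * ksum;
  const int32_t new_result = acc_signed - zero_point * ksum;
  assert(old_result == new_result);
  printf("old=%d new=%d\n", (int) old_result, (int) new_result);
  return 0;
}

Either form yields the same corrected accumulator, which is why the XOR and the `+ 128` can only be added or dropped together.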
#include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,7 +49,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -63,19 +68,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -213,9 +218,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
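The hunk just above (and the matching hunk in every avx512amx kernel in this diff) stores the AMX accumulator tiles to a stack buffer with _tile_stored and then explicitly unpoisons that buffer under MemorySanitizer. MSan instruments ordinary stores but does not model AMX tile stores, so the bytes written by _tile_stored would otherwise still read as uninitialized when loaded back with _mm512_load_epi32. Collapsing res0..resN into a single res[N][16 * 16] array also lets one __msan_unpoison(res, sizeof(res)) call cover every tile. The header names were stripped from this diff during extraction; __msan_unpoison lives in <sanitizer/msan_interface.h>, which is presumably the added include. A minimal sketch of the pattern, under those assumptions:

#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    #include <sanitizer/msan_interface.h>
  #endif
#endif
#include <immintrin.h>
#include <stdint.h>

// One 16x16 block of int32 sums per AMX tile, 64-byte aligned for the
// _mm512_load_epi32 read-back; sized here for the widest (four-tile) kernels.
__attribute__((aligned(64))) static int32_t res[4][16 * 16];

static void store_and_unpoison_tiles(void) {
  _tile_stored(0, &res[0][0], 64);  // stride 64: one 16-int32 tile row per step
  _tile_stored(1, &res[1][0], 64);
#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    __msan_unpoison(res, sizeof(res));  // mark the tile-store bytes as initialized
  #endif
#endif
}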
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[2].zero_point)); @@ -232,22 +247,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm( __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[13].zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx.c index f33282881cee..7083cc08ee6b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,7 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -62,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -180,9 +185,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // 
quantizing 1 row at a time. __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[2].zero_point)); @@ -199,22 +214,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx( __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[13].zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc9x0123456789ABCDEF = 
_mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx-prfm.c index 340e0836d6fe..2c45634f26d0 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,8 +49,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -64,19 +68,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -236,10 +240,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
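After the tiles are stored, res[t] holds a 16x16 block of int32 partial sums for output-column group t, so row r of that block begins 16 * r ints into res[t]. That is the offset pattern in the unrolled _mm512_add_epi32(..., _mm512_load_epi32(&res[t][0] + 16 * r)) lines in these kernels; a loop form of the same read-back, as a sketch only (the generated code stays fully unrolled):

#include <immintrin.h>
#include <stdint.h>
#include <stddef.h>

// Add each stored tile row into the matching accumulator register.
// acc[r][t] corresponds to vacc<r>x... for column group t in the kernels.
static void add_tiles_to_acc(__m512i acc[16][4],
                             const int32_t res[4][16 * 16], size_t num_tiles) {
  for (size_t r = 0; r < 16; r++) {
    for (size_t t = 0; t < num_tiles; t++) {
      acc[r][t] = _mm512_add_epi32(acc[r][t],
                                   _mm512_load_epi32(&res[t][0] + 16 * r));
    }
  }
}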
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); @@ -272,38 +286,39 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm( __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc13x0123456789ABCDEF = 
_mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx.c index 0af55ba885cf..a310d7fe16ab 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,8 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -63,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -187,10 +191,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
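These kernels size their K remainder the same way: kc is first rounded up to a multiple of 4 bytes (one int8 quad per dot-product lane), the main loop then consumes 64 bytes of K per step, and the final step runs on a separately configured remainder tile (tmm6/tmm7, with colsb = kremainder and rows = kremainder >> 2 in the tile configuration above). The `(kc & 63) ? (kc & 63) : 64` form makes the remainder a full 64 rather than 0 when kc is already a multiple of 64, so the last iteration never degenerates. A scalar sketch of that computation (round_up_po2 restated here for self-containment):

#include <stddef.h>
#include <stdint.h>

// Round n up to a multiple of q, where q is a power of two.
static inline size_t round_up_po2(size_t n, size_t q) {
  return (n + q - 1) & ~(q - 1);
}

static size_t compute_kremainder(size_t kc) {
  kc = round_up_po2(kc, 4 * sizeof(int8_t));   // whole int8 quads only
  return (kc & 63) ? (kc & 63) : 64;           // last K step: 1..64 bytes, never 0
}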
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); @@ -223,38 +237,39 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx( __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[14].zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc13x0123456789ABCDEF = 
_mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c index d6f09892bacb..065465114c9d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,10 +49,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -66,19 +68,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -282,12 +284,22 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if 
defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -352,70 +364,71 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = 
_mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + 
vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + 
vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx.c index 8993422cd7a1..8a5503c070a6 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-16x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,10 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -201,12 +203,22 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
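The accumulator seeding visible in the context lines below is the dynamic-quantization correction for the qd8 path: each vacc register starts at vksum (the precomputed column sum of the packed int8 weights) times that row's activation zero point from quantization_params[m], and the stored tile results are then added on top before everything is converted to float with _mm512_cvtepi32_ps and scaled. The exact sign convention of vksum is fixed by the weight-packing code, which is outside this diff. A scalar sketch of one output element, under that caveat:

#include <stdint.h>
#include <stddef.h>

// One output element of the qd8 GEMM before float conversion.
// ksum: packed-weight column sum (sign convention set by the packing code);
// zp:   per-row activation zero point (quantization_params[m].zero_point).
static int32_t seed_and_accumulate(int32_t ksum, int32_t zp,
                                   const int8_t *a, const int8_t *w, size_t kc) {
  int32_t acc = ksum * zp;             // the _mm512_mullo_epi32(vksum, set1(zp)) seed
  for (size_t k = 0; k < kc; k++) {    // what the AMX tile dot products compute in bulk
    acc += (int32_t) a[k] * (int32_t) w[k];
  }
  return acc;                          // then cvtepi32_ps and per-channel scaling
}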
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -271,70 +283,71 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[15].zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[15].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = 
_mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..41b1fcb9ef1f --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,135 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. + ldr q10, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v3.4s, v10.s[0] + mul v14.4s, v4.4s, v10.s[0] + mul v15.4s, v5.4s, v10.s[0] + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v7.16b, v2.4b[0] + sdot v14.4s, v8.16b, v2.4b[0] + sdot v15.4s, v9.16b, v2.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[1] + fmul v14.4s, v14.4s, v10.s[1] + fmul v15.4s, v15.4s, v10.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v3.4s + fmul v14.4s, v14.4s, v4.4s + fmul v15.4s, v15.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v7.4s + fadd v14.4s, v14.4s, v8.4s + fadd v15.4s, v15.4s, v9.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q13, [x6], 32 + stp q14, q15, [x6], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q13, [x6], 32 + mov v12.16b, v14.16b + mov v13.16b, v15.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + mov v12.16b, v13.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..6d8dc3d5e2de --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,101 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 528 + + # Load quantization params pointer from stack + mov r11, [rsp + 616] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + # Load quantization_params pointer from stack + mov r11, [rsp + 616] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vmaxps zmm5, zmm0, zmm5 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + add r10, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + +return: + add rsp, 528 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512amx.c index 500acf3eb5f3..cc5263ab38ae 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,7 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; + __attribute__((aligned(64))) int32_t res[1][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -62,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -120,11 +125,22 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
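For readers following the AMX tile setup, the kc rounding and the kremainder value fed into the tmm6/tmm7 configuration above are simple bit arithmetic. A worked example (round_up_po2 is XNNPACK's helper, reproduced here in an equivalent form for a power-of-two multiple):

#include <stddef.h>

static size_t round_up_po2(size_t n, size_t q) {  // equivalent form, q a power of two
  return (n + q - 1) & ~(q - 1);
}

// K is consumed 64 bytes at a time; the last block may be shorter:
//   kc = 150: round_up_po2(150, 4) == 152, kremainder = 152 & 63 == 24
//             -> tmm6 uses colsb = 24, tmm7 uses rows = 24 >> 2 == 6
//   kc = 128: 128 & 63 == 0, so kremainder falls back to 64 (one full block).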
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c index 2e81ca2bff63..e517dc1c968d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -60,8 +58,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -78,7 +76,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c index 5832f8f6a116..cad7dbf2d4a9 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -59,8 +57,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), 
vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -75,7 +73,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni( vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, vacc1x0x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c index 6b8ee70f9dfe..04ef5f6faa8f 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c @@ -45,9 +45,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -63,8 +61,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -86,7 +84,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni.c index 8e03dd8c4d24..50fbd5b4be9e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni( const int8_t* a0 = a; float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = 
_mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -81,7 +79,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..40f098232ed0 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,116 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 528 + + # Load quantization params pointer from stack + mov r11, [rsp + 616] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm7, ZMMWORD PTR [rsp + 464] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm12, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + # Load quantization_params pointer from stack + mov r11, [rsp + 616] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 4]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + + # Check whether full or partial store. 
+ cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm12 + add r10, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + +return: + add rsp, 528 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32c4-minmax-avx512amx.c index 3ff94cd8ac22..50e43fa912f0 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,8 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; + __attribute__((aligned(64))) int32_t res[2][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -63,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -127,14 +131,25 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
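On the partial-store path of the new amd64 AVX512-VNNI assembly kernels above, the opmask registers are built from the remaining channel count with a shift-and-invert sequence. An equivalent C sketch (nc < 32 on this path, so the shift is well defined; names are illustrative):

#include <stdint.h>

// Equivalent of: mov r11d, -1 ; sal r11d, cl ; not r11d ;
//                kmovw k1, r11d ; shr r11d, 16 ; kmovw k2, r11d
static void tail_masks(unsigned nc, uint16_t* k1, uint16_t* k2) {
  const uint32_t mask = ~(UINT32_C(0xFFFFFFFF) << nc);  // low `nc` bits set
  *k1 = (uint16_t) mask;          // lanes 0..15 -> masked store of the first zmm
  *k2 = (uint16_t) (mask >> 16);  // lanes 16..31 -> masked store of the second zmm
}
// Example: nc == 20 gives mask == 0x000FFFFF, so k1 == 0xFFFF and k2 == 0x000F.

The 1x64 variant extends the same idea with a 64-bit mask split across k1..k4.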
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..864746630bd4 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,147 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 528 + + # Load quantization params pointer from stack + mov r11, [rsp + 616] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm14, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm15, zmm9, ZMMWORD PTR [rsp + 464] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm12, zmm2, zmm7 + vpdpbusd zmm14, zmm2, zmm8 + vpdpbusd zmm15, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + # Load quantization_params pointer from stack + mov r11, [rsp + 616] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 4]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 4]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 4]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm11, zmm7 + vfmadd213ps zmm14, zmm2, zmm8 + vfmadd213ps zmm15, zmm3, zmm9 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm12 + vmovups [r10 + 128], zmm14 + vmovups [r10 + 192], zmm15 + add r10, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm14 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm15 + +return: + add rsp, 528 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64c4-minmax-avx512amx.c index 1cc80e052903..e3a192175440 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,10 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -65,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -141,20 +143,31 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[0].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..91aea6ddc518 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,107 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. + ldr q10, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v3.4s, v10.s[0] + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v7.16b, v2.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + + # Check whether full or partial store. 
+ cmp x1, 8 + b.lo tail_4 + stp q12, q13, [x6], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + mov v12.16b, v13.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + dup d12, v12.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u2-acc2.c index 899fb6d2cf77..d94410198fa4 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u2-acc2.c @@ -46,9 +46,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u2_acc2( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -65,9 +63,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u2_acc2( __m256i va0x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -82,8 +77,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u2_acc2( __m256i va0x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u4-acc4.c index a106c84b4c72..1bc1f161320e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c4-minmax-avxvnni-u4-acc4.c @@ -46,9 +46,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u4_acc4( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -69,11 +67,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u4_acc4( __m256i va0x3x0123 = _mm256_set1_epi32((int) 
unaligned_load_u32(a0 + 12)); a0 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -94,8 +87,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u4_acc4( __m256i va0x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c index b8d1359de839..cfb3e61f2f3d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -83,7 +81,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni.c index 617cdc865fba..8544a72ed18a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avx256vnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) 
quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -61,8 +59,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -80,7 +78,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c index eee895b50ddb..106b52ab06da 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni-prfm.c @@ -44,9 +44,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -62,8 +60,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -83,7 +81,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni.c index e0f1d0c1414a..ba0eb27b7cca 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8c8-minmax-avxvnni.c @@ -43,9 +43,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni( const int8_t* a0 = a; float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - 
XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -61,8 +59,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -80,7 +78,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..68d29de372d0 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,185 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v3.4s, v10.s[0] + mul v15.4s, v3.4s, v10.s[2] + mul v16.4s, v4.4s, v10.s[0] + mul v17.4s, v4.4s, v10.s[2] + mul v18.4s, v5.4s, v10.s[0] + mul v19.4s, v5.4s, v10.s[2] + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v7.16b, v2.4b[0] + sdot v15.4s, v7.16b, v3.4b[0] + sdot v16.4s, v8.16b, v2.4b[0] + sdot v17.4s, v8.16b, v3.4b[0] + sdot v18.4s, v9.16b, v2.4b[0] + sdot v19.4s, v9.16b, v3.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + # Multiply by input scale. 
+ fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v10.s[1] + fmul v15.4s, v15.4s, v10.s[3] + fmul v16.4s, v16.4s, v10.s[1] + fmul v17.4s, v17.4s, v10.s[3] + fmul v18.4s, v18.4s, v10.s[1] + fmul v19.4s, v19.4s, v10.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v3.4s + fmul v15.4s, v15.4s, v3.4s + fmul v16.4s, v16.4s, v4.4s + fmul v17.4s, v17.4s, v4.4s + fmul v18.4s, v18.4s, v5.4s + fmul v19.4s, v19.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v7.4s + fadd v15.4s, v15.4s, v7.4s + fadd v16.4s, v16.4s, v8.4s + fadd v17.4s, v17.4s, v8.4s + fadd v18.4s, v18.4s, v9.4s + fadd v19.4s, v19.4s, v9.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q14, [x6], 32 + stp q16, q18, [x6], 32 + stp q13, q15, [x13], 32 + stp q17, q19, [x13], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q14, [x6], 32 + stp q13, q15, [x13], 32 + mov v12.16b, v16.16b + mov v14.16b, v18.16b + mov v13.16b, v17.16b + mov v15.16b, v19.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + mov v12.16b, v14.16b + mov v13.16b, v15.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..34b63a18d390 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,124 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 592 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 680] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + # Load quantization_params pointer from stack + mov r11, [rsp + 680] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + add r10, 64 + add r13, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + +return: + add rsp, 592 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..955f2135499f --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,148 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 592 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 680] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. 
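    # (Reference note, assumptions taken from the surrounding loads: the packed weights
    #  at r9 start each column block with the per-column sums of the int8 weights, and
    #  the broadcast per-row zero points were spilled to [rsp + 464]/[rsp + 528] above.
    #  Seeding acc[m][n] = ksum[n] * zero_point[m] folds the input zero-point correction
    #  into the vpdpbusd accumulation that follows; the exact sign convention of ksum is
    #  assumed to be handled by the weight-packing step.)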
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 528] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm14, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + # Load quantization_params pointer from stack + mov r11, [rsp + 680] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 4]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 12]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm11, zmm7 + vfmadd213ps zmm15, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm14 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm15 + add r10, 128 + add r13, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + +return: + add rsp, 592 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..9dc74973e555 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,197 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 592 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 680] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm16, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm17, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm18, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm19, zmm9, ZMMWORD PTR [rsp + 528] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm14, zmm2, zmm7 + vpdpbusd zmm16, zmm2, zmm8 + vpdpbusd zmm18, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + vpdpbusd zmm17, zmm2, zmm8 + vpdpbusd zmm19, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + # Load quantization_params pointer from stack + mov r11, [rsp + 680] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 4]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 12]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 4]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 12]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 4]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 12]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm11, zmm7 + vfmadd213ps zmm15, zmm11, zmm7 + vfmadd213ps zmm16, zmm2, zmm8 + vfmadd213ps zmm17, zmm2, zmm8 + vfmadd213ps zmm18, zmm3, zmm9 + vfmadd213ps zmm19, zmm3, zmm9 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Check whether full or partial store. 
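    # (Reference note: the tail path below builds the remaining-column mask as
    #  (1 << nc_left) - 1 by shifting a -1 seed with `sal` and inverting it with `not`,
    #  then peels it 16 bits at a time into k1..k4 so each of the four 16-float ZMM
    #  blocks is written with a masked vmovups instead of a scalar store loop.)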
+ cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm14 + vmovups [r10 + 128], zmm16 + vmovups [r10 + 192], zmm18 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm15 + vmovups [r13 + 128], zmm17 + vmovups [r13 + 192], zmm19 + add r10, 256 + add r13, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm16 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm18 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm17 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm19 + +return: + add rsp, 592 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..3e2a668453b8 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,137 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x13, x6, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v3.4s, v10.s[0] + mul v15.4s, v3.4s, v10.s[2] + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v7.16b, v2.4b[0] + sdot v15.4s, v7.16b, v3.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v10.s[1] + fmul v15.4s, v15.4s, v10.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v3.4s + fmul v15.4s, v15.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v7.4s + fadd v15.4s, v15.4s, v7.4s + # Min/max clamping.. 
+ fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q12, q14, [x6], 32 + stp q13, q15, [x13], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + mov v12.16b, v14.16b + mov v13.16b, v15.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u2-acc2.c index 49582b22941e..6d002746e06c 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u2-acc2.c @@ -52,10 +52,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u2_acc2( c1 = c0; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -77,11 +75,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u2_acc2( __m256i va1x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -101,9 +94,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u2_acc2( __m256i va1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u4-acc4.c index 5add4e655313..cbea6640ab38 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u4-acc4.c +++ 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c4-minmax-avxvnni-u4-acc4.c @@ -52,10 +52,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u4_acc4( c1 = c0; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -85,15 +83,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u4_acc4( __m256i va1x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a1 + 12)); a1 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -123,9 +112,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u4_acc4( __m256i va1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c index 5ad722e5561f..2e00c6440975 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni-prfm.c @@ -50,10 +50,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -74,11 +72,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -102,9 +100,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni.c index 16049bc010b1..ee00616ef176 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8c8-minmax-avxvnni.c @@ -49,10 +49,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -73,11 +71,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -99,9 +97,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..4720444c9cd7 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,233 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldr q10, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v3.4s, v10.s[0] + mul v16.4s, v3.4s, v10.s[2] + mul v17.4s, v3.4s, v11.s[0] + mul v18.4s, v4.4s, v10.s[0] + mul v19.4s, v4.4s, v10.s[2] + mul v20.4s, v4.4s, v11.s[0] + mul v21.4s, v5.4s, v10.s[0] + mul v22.4s, v5.4s, v10.s[2] + mul v23.4s, v5.4s, v11.s[0] + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v7.16b, v2.4b[0] + sdot v16.4s, v7.16b, v3.4b[0] + sdot v17.4s, v7.16b, v4.4b[0] + sdot v18.4s, v8.16b, v2.4b[0] + sdot v19.4s, v8.16b, v3.4b[0] + sdot v20.4s, v8.16b, v4.4b[0] + sdot v21.4s, v9.16b, v2.4b[0] + sdot v22.4s, v9.16b, v3.4b[0] + sdot v23.4s, v9.16b, v4.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v10.s[1] + fmul v16.4s, v16.4s, v10.s[3] + fmul v17.4s, v17.4s, v11.s[1] + fmul v18.4s, v18.4s, v10.s[1] + fmul v19.4s, v19.4s, v10.s[3] + fmul v20.4s, v20.4s, v11.s[1] + fmul v21.4s, v21.4s, v10.s[1] + fmul v22.4s, v22.4s, v10.s[3] + fmul v23.4s, v23.4s, v11.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. 
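    # (Reference note on the packed-weight layout assumed behind x5, inferred from the
    #  load order in this kernel: each block of 16 output channels is streamed as the
    #  int32 column sums, then the int8 weights for every k, then the f32 per-channel
    #  scales, and finally the f32 biases loaded here.)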
+ ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v3.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + fmul v18.4s, v18.4s, v4.4s + fmul v19.4s, v19.4s, v4.4s + fmul v20.4s, v20.4s, v4.4s + fmul v21.4s, v21.4s, v5.4s + fmul v22.4s, v22.4s, v5.4s + fmul v23.4s, v23.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v7.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + fadd v18.4s, v18.4s, v8.4s + fadd v19.4s, v19.4s, v8.4s + fadd v20.4s, v20.4s, v8.4s + fadd v21.4s, v21.4s, v9.4s + fadd v22.4s, v22.4s, v9.4s + fadd v23.4s, v23.4s, v9.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q15, [x6], 32 + stp q18, q21, [x6], 32 + stp q13, q16, [x13], 32 + stp q19, q22, [x13], 32 + stp q14, q17, [x14], 32 + stp q20, q23, [x14], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q15, [x6], 32 + stp q13, q16, [x13], 32 + stp q14, q17, [x14], 32 + mov v12.16b, v18.16b + mov v15.16b, v21.16b + mov v13.16b, v19.16b + mov v16.16b, v22.16b + mov v14.16b, v20.16b + mov v17.16b, v23.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + mov v12.16b, v15.16b + mov v13.16b, v16.16b + mov v14.16b, v17.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..7d65da61eb9f --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,147 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. 
+ mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 656 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 744] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + # Load quantization_params pointer from stack + mov r11, [rsp + 744] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + vmovups [rbx], zmm14 + add r10, 64 + add r13, 64 + add rbx, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + +return: + add rsp, 656 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..836ea2aaa287 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,180 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. 
+ mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 656 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 744] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 592] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + # Load quantization_params pointer from stack + mov r11, [rsp + 744] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 4]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 12]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 20]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm11, zmm7 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Check whether full or partial store. 
+ cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm15 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm16 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm17 + add r10, 128 + add r13, 128 + add rbx, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + +return: + add rsp, 656 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..0382f92d2209 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,247 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 656 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 744] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm18, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm19, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm20, zmm8, ZMMWORD PTR [rsp + 592] + vpmulld zmm21, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm22, zmm9, ZMMWORD PTR [rsp + 528] + vpmulld zmm23, zmm9, ZMMWORD PTR [rsp + 592] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm15, zmm2, zmm7 + vpdpbusd zmm18, zmm2, zmm8 + vpdpbusd zmm21, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpdpbusd zmm19, zmm2, zmm8 + vpdpbusd zmm22, zmm2, zmm9 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpdpbusd zmm20, zmm2, zmm8 + vpdpbusd zmm23, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + # Load quantization_params pointer from stack + mov r11, [rsp + 744] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 4]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 12]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 20]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 4]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 12]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 20]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 4]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 12]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 20]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm11, zmm7 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm2, zmm8 + vfmadd213ps zmm19, zmm2, zmm8 + vfmadd213ps zmm20, zmm2, zmm8 + vfmadd213ps zmm21, zmm3, zmm9 + vfmadd213ps zmm22, zmm3, zmm9 + vfmadd213ps zmm23, zmm3, zmm9 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm15 + vmovups [r10 + 128], zmm18 + vmovups [r10 + 192], zmm21 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm16 + vmovups [r13 + 128], zmm19 + vmovups [r13 + 192], zmm22 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm17 + vmovups [rbx + 128], zmm20 + vmovups [rbx + 192], zmm23 + add r10, 256 + add r13, 256 + add rbx, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm18 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm21 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm19 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm22 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm20 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm23 + +return: + add rsp, 656 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..f947587a63e8 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,165 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x13, x6, x7 + add x14, x13, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. 
+ ldp q10, q11, [x24] + ldr q10, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v3.4s, v10.s[0] + mul v16.4s, v3.4s, v10.s[2] + mul v17.4s, v3.4s, v11.s[0] + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v7.16b, v2.4b[0] + sdot v16.4s, v7.16b, v3.4b[0] + sdot v17.4s, v7.16b, v4.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v10.s[1] + fmul v16.4s, v16.4s, v10.s[3] + fmul v17.4s, v17.4s, v11.s[1] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v3.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v7.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + + # Check whether full or partial store. + cmp x1, 8 + b.lo tail_4 + stp q12, q15, [x6], 32 + stp q13, q16, [x13], 32 + stp q14, q17, [x14], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + mov v12.16b, v15.16b + mov v13.16b, v16.16b + mov v14.16b, v17.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. 
+ ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u2-acc2.c index 4ba1e167e0ad..402d1095825b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u2-acc2.c @@ -58,11 +58,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u2_acc2( c2 = c1; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -89,13 +87,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u2_acc2( __m256i va2x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -120,10 +111,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u2_acc2( __m256i va2x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u4-acc4.c index b4f2a32fa89b..14a09d8eb248 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c4-minmax-avxvnni-u4-acc4.c @@ -58,11 +58,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u4_acc4( c2 = c1; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) 
quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,19 +99,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u4_acc4( __m256i va2x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a2 + 12)); a2 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va2x2x0123 = _mm256_xor_si256(va2x2x0123, vsign_mask); - va2x3x0123 = _mm256_xor_si256(va2x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -152,10 +137,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u4_acc4( __m256i va2x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c index 03c0813449f3..558022a222ea 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni-prfm.c @@ -56,11 +56,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -80,14 +78,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -115,11 +113,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni.c index 1eb18cce308c..4168a39123f5 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8c8-minmax-avxvnni.c @@ -55,11 +55,9 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -79,14 +77,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -112,11 +110,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..8663d8c685cc --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,281 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. + ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. 
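+ # x24 holds the per-row quantization params (zero-point/scale pairs) and x5 the packed weights, which begin with one int32 k_sum per output channel.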
+ ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v2.4s, v11.s[2] + mul v16.4s, v3.4s, v10.s[0] + mul v17.4s, v3.4s, v10.s[2] + mul v18.4s, v3.4s, v11.s[0] + mul v19.4s, v3.4s, v11.s[2] + mul v20.4s, v4.4s, v10.s[0] + mul v21.4s, v4.4s, v10.s[2] + mul v22.4s, v4.4s, v11.s[0] + mul v23.4s, v4.4s, v11.s[2] + mul v24.4s, v5.4s, v10.s[0] + mul v25.4s, v5.4s, v10.s[2] + mul v26.4s, v5.4s, v11.s[0] + mul v27.4s, v5.4s, v11.s[2] + add x5, x5, 64 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldr d5, [x11, x20] + ldp q6, q7, [x5], 32 + ldp q8, q9, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v6.16b, v5.4b[0] + sdot v16.4s, v7.16b, v2.4b[0] + sdot v17.4s, v7.16b, v3.4b[0] + sdot v18.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v5.4b[0] + sdot v20.4s, v8.16b, v2.4b[0] + sdot v21.4s, v8.16b, v3.4b[0] + sdot v22.4s, v8.16b, v4.4b[0] + sdot v23.4s, v8.16b, v5.4b[0] + sdot v24.4s, v9.16b, v2.4b[0] + sdot v25.4s, v9.16b, v3.4b[0] + sdot v26.4s, v9.16b, v4.4b[0] + sdot v27.4s, v9.16b, v5.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v11.s[3] + fmul v16.4s, v16.4s, v10.s[1] + fmul v17.4s, v17.4s, v10.s[3] + fmul v18.4s, v18.4s, v11.s[1] + fmul v19.4s, v19.4s, v11.s[3] + fmul v20.4s, v20.4s, v10.s[1] + fmul v21.4s, v21.4s, v10.s[3] + fmul v22.4s, v22.4s, v11.s[1] + fmul v23.4s, v23.4s, v11.s[3] + fmul v24.4s, v24.4s, v10.s[1] + fmul v25.4s, v25.4s, v10.s[3] + fmul v26.4s, v26.4s, v11.s[1] + fmul v27.4s, v27.4s, v11.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + ldp q4, q5, [x5, 32] + add x5, x5, 64 + # Load biases. + ldp q6, q7, [x5, 0] + ldp q8, q9, [x5, 32] + add x5, x5, 64 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v2.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + fmul v18.4s, v18.4s, v3.4s + fmul v19.4s, v19.4s, v3.4s + fmul v20.4s, v20.4s, v4.4s + fmul v21.4s, v21.4s, v4.4s + fmul v22.4s, v22.4s, v4.4s + fmul v23.4s, v23.4s, v4.4s + fmul v24.4s, v24.4s, v5.4s + fmul v25.4s, v25.4s, v5.4s + fmul v26.4s, v26.4s, v5.4s + fmul v27.4s, v27.4s, v5.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v6.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + fadd v18.4s, v18.4s, v7.4s + fadd v19.4s, v19.4s, v7.4s + fadd v20.4s, v20.4s, v8.4s + fadd v21.4s, v21.4s, v8.4s + fadd v22.4s, v22.4s, v8.4s + fadd v23.4s, v23.4s, v8.4s + fadd v24.4s, v24.4s, v9.4s + fadd v25.4s, v25.4s, v9.4s + fadd v26.4s, v26.4s, v9.4s + fadd v27.4s, v27.4s, v9.4s + # Min/max clamping.. 
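+ # v1 = output_max, v0 = output_min (broadcast by the ld2r above); fmin then fmax clamps each accumulator to [min, max].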
+ fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmin v20.4s, v1.4s, v20.4s + fmin v21.4s, v1.4s, v21.4s + fmin v22.4s, v1.4s, v22.4s + fmin v23.4s, v1.4s, v23.4s + fmin v24.4s, v1.4s, v24.4s + fmin v25.4s, v1.4s, v25.4s + fmin v26.4s, v1.4s, v26.4s + fmin v27.4s, v1.4s, v27.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + fmax v20.4s, v0.4s, v20.4s + fmax v21.4s, v0.4s, v21.4s + fmax v22.4s, v0.4s, v22.4s + fmax v23.4s, v0.4s, v23.4s + fmax v24.4s, v0.4s, v24.4s + fmax v25.4s, v0.4s, v25.4s + fmax v26.4s, v0.4s, v26.4s + fmax v27.4s, v0.4s, v27.4s + + # Check whether full or partial store. + cmp x1, 16 + b.lo tail_8 + stp q12, q16, [x6], 32 + stp q20, q24, [x6], 32 + stp q13, q17, [x13], 32 + stp q21, q25, [x13], 32 + stp q14, q18, [x14], 32 + stp q22, q26, [x14], 32 + stp q15, q19, [x15], 32 + stp q23, q27, [x15], 32 + + sub x1, x1, 16 + b.ne outer_loop + b return + +tail_8: + tbz x1, 3, tail_4 + stp q12, q16, [x6], 32 + stp q13, q17, [x13], 32 + stp q14, q18, [x14], 32 + stp q15, q19, [x15], 32 + mov v12.16b, v20.16b + mov v16.16b, v24.16b + mov v13.16b, v21.16b + mov v17.16b, v25.16b + mov v14.16b, v22.16b + mov v18.16b, v26.16b + mov v15.16b, v23.16b + mov v19.16b, v27.16b + + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + str q15, [x15], 16 + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + mov v15.16b, v19.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + str d15, [x15], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + str s15, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..02e801e11dcc --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,170 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
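+ # (Stack-passed arguments sit above the six pushed registers and the return address: c at [rsp + 56], cm_stride at [rsp + 64], params at [rsp + 80], quantization_params at [rsp + 88].)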
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 720 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Load quantization params pointer from stack + mov r11, [rsp + 808] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + # Load quantization_params pointer from stack + mov r11, [rsp + 808] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + vmovups [rbx], zmm14 + vmovups [rbp], zmm15 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + +return: + add rsp, 720 + + # Restore the callee saved registers. 
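+ # (rsp was already restored above; the pops mirror the prologue pushes in reverse order.)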
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c index 434f06304d48..3d4aca9b25c3 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c @@ -63,12 +63,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -87,17 +85,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -123,13 +121,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm( vacc3x0123456789ABCDEF = 
_mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c index 386052403d06..c7ddef93e851 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -86,17 +84,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = 
_mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -120,13 +118,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni( vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, vacc1x3x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..3e7674a87fea --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,212 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 720 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Load quantization params pointer from stack + mov r11, [rsp + 808] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. 
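+ # 32 output channels: two zmm vectors of packed k_sums (zmm6, zmm7), each multiplied by the four broadcast per-row zero points saved to the stack above.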
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 656] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + # Load quantization_params pointer from stack + mov r11, [rsp + 808] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 4]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 12]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 20]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 28]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm16 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm17 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm18 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm19 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + +return: + add rsp, 720 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..09d697c5aebb --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,297 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 720 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Load quantization params pointer from stack + mov r11, [rsp + 808] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. 
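+ # 64 output channels per outer iteration: four zmm vectors of packed k_sums (zmm6-zmm9), giving sixteen int32 accumulators across the four rows.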
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm20, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm21, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm22, zmm8, ZMMWORD PTR [rsp + 592] + vpmulld zmm23, zmm8, ZMMWORD PTR [rsp + 656] + vpmulld zmm24, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm25, zmm9, ZMMWORD PTR [rsp + 528] + vpmulld zmm26, zmm9, ZMMWORD PTR [rsp + 592] + vpmulld zmm27, zmm9, ZMMWORD PTR [rsp + 656] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm16, zmm2, zmm7 + vpdpbusd zmm20, zmm2, zmm8 + vpdpbusd zmm24, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpdpbusd zmm21, zmm2, zmm8 + vpdpbusd zmm25, zmm2, zmm9 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpdpbusd zmm22, zmm2, zmm8 + vpdpbusd zmm26, zmm2, zmm9 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpdpbusd zmm23, zmm2, zmm8 + vpdpbusd zmm27, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + # Load quantization_params pointer from stack + mov r11, [rsp + 808] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 4]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 12]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 20]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 28]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 4]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 12]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 20]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 28]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 4]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 12]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 20]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 28]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm11, zmm7 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm2, zmm8 + vfmadd213ps zmm21, zmm2, zmm8 + vfmadd213ps zmm22, zmm2, 
zmm8 + vfmadd213ps zmm23, zmm2, zmm8 + vfmadd213ps zmm24, zmm3, zmm9 + vfmadd213ps zmm25, zmm3, zmm9 + vfmadd213ps zmm26, zmm3, zmm9 + vfmadd213ps zmm27, zmm3, zmm9 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm16 + vmovups [r10 + 128], zmm20 + vmovups [r10 + 192], zmm24 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm17 + vmovups [r13 + 128], zmm21 + vmovups [r13 + 192], zmm25 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm18 + vmovups [rbx + 128], zmm22 + vmovups [rbx + 192], zmm26 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm19 + vmovups [rbp + 128], zmm23 + vmovups [rbp + 192], zmm27 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm20 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm24 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm21 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm25 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm26 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm23 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm27 + +return: + add rsp, 720 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S new file mode 100644 index 000000000000..de28ec7033ab --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S @@ -0,0 +1,193 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane + + # Free up GP registers. + stp x19, x20, [sp, -48] + stp x21, x22, [sp, -32] + stp x23, x24, [sp, -16] + + # Preserve callee saved q8-q15 registers. + stp q8, q9, [sp, -176] + stp q10, q11, [sp, -144] + stp q12, q13, [sp, -112] + stp q14, q15, [sp, -80] + + # Load params. 
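+ # Arguments beyond x0-x7 are passed on the stack: params at [sp, 8], quantization_params at [sp, 16].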
+ ldr x13, [sp, 8] + + # Load min/max values. + ld2r {v0.4s, v1.4s}, [x13] + ldr x24, [sp, 16] + # Round kc up to channels. + add x2, x2, #3 + and x2, x2, #0xFFFFFFFFFFFFFFFC + + # Setup and alias a & c pointers. + add x9, x3, x4 + add x10, x9, x4 + add x11, x10, x4 + add x13, x6, x7 + add x14, x13, x7 + add x15, x14, x7 + + cmp x0, 2 + csel x9, x3, x9, LO + csel x13, x6, x13, LO + csel x10, x9, x10, LS + csel x14, x13, x14, LS + + cmp x0, 4 + csel x11, x10, x11, LO + csel x15, x14, x15, LO + +outer_loop: + # Zero k counter. + eor x20, x20, x20 + # Initialize accumulators with k_sum * input zero point. + ldp q10, q11, [x24] + ldp q2, q3, [x5, 0] + mul v12.4s, v2.4s, v10.s[0] + mul v13.4s, v2.4s, v10.s[2] + mul v14.4s, v2.4s, v11.s[0] + mul v15.4s, v2.4s, v11.s[2] + mul v16.4s, v3.4s, v10.s[0] + mul v17.4s, v3.4s, v10.s[2] + mul v18.4s, v3.4s, v11.s[0] + mul v19.4s, v3.4s, v11.s[2] + add x5, x5, 32 + +inner_loop: + ldr d2, [x3, x20] + ldr d3, [x9, x20] + ldr d4, [x10, x20] + ldr d5, [x11, x20] + ldp q6, q7, [x5], 32 + sdot v12.4s, v6.16b, v2.4b[0] + sdot v13.4s, v6.16b, v3.4b[0] + sdot v14.4s, v6.16b, v4.4b[0] + sdot v15.4s, v6.16b, v5.4b[0] + sdot v16.4s, v7.16b, v2.4b[0] + sdot v17.4s, v7.16b, v3.4b[0] + sdot v18.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v5.4b[0] + add x20, x20, 4 + cmp x2, x20 + bne inner_loop + + # Convert from int32 to float. + scvtf v12.4s, v12.4s + scvtf v13.4s, v13.4s + scvtf v14.4s, v14.4s + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + # Multiply by input scale. + fmul v12.4s, v12.4s, v10.s[1] + fmul v13.4s, v13.4s, v10.s[3] + fmul v14.4s, v14.4s, v11.s[1] + fmul v15.4s, v15.4s, v11.s[3] + fmul v16.4s, v16.4s, v10.s[1] + fmul v17.4s, v17.4s, v10.s[3] + fmul v18.4s, v18.4s, v11.s[1] + fmul v19.4s, v19.4s, v11.s[3] + # Load weights scale. + ldp q2, q3, [x5, 0] + add x5, x5, 32 + # Load biases. + ldp q6, q7, [x5, 0] + add x5, x5, 32 + # Multiply by weight's scale. + fmul v12.4s, v12.4s, v2.4s + fmul v13.4s, v13.4s, v2.4s + fmul v14.4s, v14.4s, v2.4s + fmul v15.4s, v15.4s, v2.4s + fmul v16.4s, v16.4s, v3.4s + fmul v17.4s, v17.4s, v3.4s + fmul v18.4s, v18.4s, v3.4s + fmul v19.4s, v19.4s, v3.4s + # Add bias. + fadd v12.4s, v12.4s, v6.4s + fadd v13.4s, v13.4s, v6.4s + fadd v14.4s, v14.4s, v6.4s + fadd v15.4s, v15.4s, v6.4s + fadd v16.4s, v16.4s, v7.4s + fadd v17.4s, v17.4s, v7.4s + fadd v18.4s, v18.4s, v7.4s + fadd v19.4s, v19.4s, v7.4s + # Min/max clamping.. + fmin v12.4s, v1.4s, v12.4s + fmin v13.4s, v1.4s, v13.4s + fmin v14.4s, v1.4s, v14.4s + fmin v15.4s, v1.4s, v15.4s + fmin v16.4s, v1.4s, v16.4s + fmin v17.4s, v1.4s, v17.4s + fmin v18.4s, v1.4s, v18.4s + fmin v19.4s, v1.4s, v19.4s + fmax v12.4s, v0.4s, v12.4s + fmax v13.4s, v0.4s, v13.4s + fmax v14.4s, v0.4s, v14.4s + fmax v15.4s, v0.4s, v15.4s + fmax v16.4s, v0.4s, v16.4s + fmax v17.4s, v0.4s, v17.4s + fmax v18.4s, v0.4s, v18.4s + fmax v19.4s, v0.4s, v19.4s + + # Check whether full or partial store. 
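+ # x1 holds the remaining output channels (nc); fewer than 8 falls through tail_4/tail_2/tail_1, each storing a smaller piece and shifting the surviving lanes down.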
+ cmp x1, 8 + b.lo tail_4 + stp q12, q16, [x6], 32 + stp q13, q17, [x13], 32 + stp q14, q18, [x14], 32 + stp q15, q19, [x15], 32 + + sub x1, x1, 8 + b.ne outer_loop + b return + +tail_4: + tbz x1, 2, tail_2 + str q12, [x6], 16 + str q13, [x13], 16 + str q14, [x14], 16 + str q15, [x15], 16 + mov v12.16b, v16.16b + mov v13.16b, v17.16b + mov v14.16b, v18.16b + mov v15.16b, v19.16b + + +tail_2: + tbz x1, 1, tail_1 + str d12, [x6], 8 + str d13, [x13], 8 + str d14, [x14], 8 + str d15, [x15], 8 + dup d12, v12.d[1] + dup d13, v13.d[1] + dup d14, v14.d[1] + dup d15, v15.d[1] + + +tail_1: + tbz x1, 0, return + str s12, [x6] + str s13, [x13] + str s14, [x14] + str s15, [x15] + +return: + # Restore the callee saved GP registers. + ldp x19, x20, [sp, -48] + ldp x21, x22, [sp, -32] + ldp x23, x24, [sp, -16] + + # Restore callee saved q8-q15 registers. + ldp q8, q9, [sp, -176] + ldp q10, q11, [sp, -144] + ldp q12, q13, [sp, -112] + ldp q14, q15, [sp, -80] + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u2-acc2.c index 1fa1d16494e7..6f366fd42275 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u2-acc2.c @@ -64,12 +64,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u2_acc2( c3 = c2; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,15 +99,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u2_acc2( __m256i va3x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -139,11 +128,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u2_acc2( __m256i va3x0123 = 
_mm256_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u4-acc4.c index 7e61049aa8ff..0bb3bf290c3b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-avxvnni-u4-acc4.c @@ -64,12 +64,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u4_acc4( c3 = c2; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -117,23 +115,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u4_acc4( __m256i va3x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a3 + 12)); a3 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va2x2x0123 = _mm256_xor_si256(va2x2x0123, vsign_mask); - va2x3x0123 = _mm256_xor_si256(va2x3x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va3x2x0123 = _mm256_xor_si256(va3x2x0123, vsign_mask); - va3x3x0123 = _mm256_xor_si256(va3x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -181,11 +162,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u4_acc4( __m256i va3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - 
va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c index 2e3ca5dcdddd..c03c406c14ea 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni-prfm.c @@ -62,12 +62,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -90,17 +88,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const 
__m256i vb01234567x0123 = _mm256_load_si256(w); @@ -132,13 +130,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni.c index 029c905fb34e..c0e93d3c2f16 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c8-minmax-avxvnni.c @@ -61,12 +61,10 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -89,17 +87,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -129,13 +127,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..7e6dfb52b2bd --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,193 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. 
+ mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 784 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Load quantization params pointer from stack + mov r11, [rsp + 872] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + # Load quantization_params pointer from stack + mov r11, [rsp + 872] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Check whether full or partial store. 
+ cmp rcx, 16 + jl tail + + vmovups [r10], zmm5 + vmovups [r13], zmm12 + vmovups [rbx], zmm14 + vmovups [rbp], zmm15 + vmovups [r8], zmm16 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [r8]{k1}, zmm16 + +return: + add rsp, 784 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c index 2e4158b868f0..a0efe2c9946b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -96,20 +94,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -138,15 +136,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c index 6ea1756556cd..f5069f7814f2 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c4-minmax-avx512vnni.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) 
quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -95,20 +93,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -135,15 +133,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni( vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, vacc1x4x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); 
+ const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni-prfm.c index 45589965a86b..6ed501485e33 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni-prfm.c @@ -69,13 +69,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -101,20 +99,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -152,15 +150,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni.c index fa4a9b378b48..1a0421e46d21 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16c8-minmax-avx512vnni.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -100,20 +98,20 @@ void 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -147,15 +145,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..505d355c8ab5 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,244 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 784 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Load quantization params pointer from stack + mov r11, [rsp + 872] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 720] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + # Load quantization_params pointer from stack + mov r11, [rsp + 872] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 4]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 12]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 20]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 28]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 36]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm17 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm18 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm19 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm20 + vmovups [r8], zmm16 + vmovups [r8 + 64], zmm21 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r8]{k1}, zmm16 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm21 + +return: + add rsp, 784 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..68e56d0a1bae --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,347 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 784 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Load quantization params pointer from stack + mov r11, [rsp + 872] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm22, zmm8, ZMMWORD PTR [rsp + 464] + vpmulld zmm23, zmm8, ZMMWORD PTR [rsp + 528] + vpmulld zmm24, zmm8, ZMMWORD PTR [rsp + 592] + vpmulld zmm25, zmm8, ZMMWORD PTR [rsp + 656] + vpmulld zmm26, zmm8, ZMMWORD PTR [rsp + 720] + vpmulld zmm27, zmm9, ZMMWORD PTR [rsp + 464] + vpmulld zmm28, zmm9, ZMMWORD PTR [rsp + 528] + vpmulld zmm29, zmm9, ZMMWORD PTR [rsp + 592] + vpmulld zmm30, zmm9, ZMMWORD PTR [rsp + 656] + vpmulld zmm4, zmm9, ZMMWORD PTR [rsp + 720] + add r9, 256 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm17, zmm2, zmm7 + vpdpbusd zmm22, zmm2, zmm8 + vpdpbusd zmm27, zmm2, zmm9 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpdpbusd zmm23, zmm2, zmm8 + vpdpbusd zmm28, zmm2, zmm9 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpdpbusd zmm24, zmm2, zmm8 + vpdpbusd zmm29, zmm2, zmm9 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpdpbusd zmm25, zmm2, zmm8 + vpdpbusd zmm30, zmm2, zmm9 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpdpbusd zmm26, zmm2, zmm8 + vpdpbusd zmm4, zmm2, zmm9 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + vcvtdq2ps zmm30, zmm30 + vcvtdq2ps zmm4, zmm4 + # Load quantization_params pointer from stack + mov r11, [rsp + 872] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 4]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 12]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 20]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 28]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 36]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 4]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 12]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 20]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 28]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 36]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 4]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 12]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 20]{1to16} + vmulps zmm30, zmm30, DWORD PTR [r11 + 28]{1to16} + vmulps zmm4, zmm4, DWORD PTR [r11 + 36]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + vmovaps zmm2, [r9 + 128] + vmovaps zmm3, [r9 + 192] + add r9, 256 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vmovaps zmm8, [r9 + 128] + vmovaps zmm9, [r9 + 192] + add r9, 256 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm11, zmm7 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm2, zmm8 + vfmadd213ps zmm23, zmm2, zmm8 + vfmadd213ps zmm24, zmm2, zmm8 + vfmadd213ps zmm25, zmm2, zmm8 + vfmadd213ps zmm26, zmm2, zmm8 + vfmadd213ps zmm27, zmm3, zmm9 + vfmadd213ps zmm28, zmm3, zmm9 + vfmadd213ps zmm29, zmm3, zmm9 + vfmadd213ps zmm30, zmm3, zmm9 + vfmadd213ps zmm4, zmm3, zmm9 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vminps zmm30, zmm1, zmm30 + vminps zmm4, zmm1, zmm4 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + vmaxps zmm30, zmm0, zmm30 + vmaxps zmm4, zmm0, zmm4 + + # Check whether full or partial store. + cmp rcx, 64 + jl tail + + vmovups [r10], zmm5 + vmovups [r10 + 64], zmm17 + vmovups [r10 + 128], zmm22 + vmovups [r10 + 192], zmm27 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm18 + vmovups [r13 + 128], zmm23 + vmovups [r13 + 192], zmm28 + vmovups [rbx], zmm14 + vmovups [rbx + 64], zmm19 + vmovups [rbx + 128], zmm24 + vmovups [rbx + 192], zmm29 + vmovups [rbp], zmm15 + vmovups [rbp + 64], zmm20 + vmovups [rbp + 128], zmm25 + vmovups [rbp + 192], zmm30 + vmovups [r8], zmm16 + vmovups [r8 + 64], zmm21 + vmovups [r8 + 128], zmm26 + vmovups [r8 + 192], zmm4 + add r10, 256 + add r13, 256 + add rbx, 256 + add rbp, 256 + add r8, 256 + + sub rcx, 64 + jne outer_loop + jmp return + +tail: + mov r11, -1 + sal r11, cl + not r11 + kmovw k1, r11d + shr r11, 16 + kmovw k2, r11d + shr r11, 16 + kmovw k3, r11d + shr r11, 16 + kmovw k4, r11d + + vmovups ZMMWORD PTR [r10]{k1}, zmm5 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm22 + vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm27 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm23 + vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm28 + vmovups ZMMWORD PTR [rbx]{k1}, zmm14 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm24 + vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm29 + vmovups ZMMWORD PTR [rbp]{k1}, zmm15 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [rbp + 128]{k3}, zmm25 + vmovups ZMMWORD PTR [rbp + 192]{k4}, zmm30 + vmovups ZMMWORD PTR [r8]{k1}, zmm16 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r8 + 128]{k3}, zmm26 + vmovups ZMMWORD PTR [r8 + 192]{k4}, zmm4 + +return: + add rsp, 784 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u2-acc2.c index 5f8520af8143..16166f05bb0b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u2-acc2.c @@ -70,13 +70,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u2_acc2( c4 = c3; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -113,17 +111,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u2_acc2( __m256i va4x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -158,12 +145,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u2_acc2( __m256i va4x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u4-acc4.c 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u4-acc4.c index bb7a26ddfb05..ce10c0f45fa3 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c4-minmax-avxvnni-u4-acc4.c @@ -70,13 +70,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u4_acc4( c4 = c3; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -133,27 +131,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u4_acc4( __m256i va4x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a4 + 12)); a4 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va2x2x0123 = _mm256_xor_si256(va2x2x0123, vsign_mask); - va2x3x0123 = _mm256_xor_si256(va2x3x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va3x2x0123 = _mm256_xor_si256(va3x2x0123, vsign_mask); - va3x3x0123 = _mm256_xor_si256(va3x3x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va4x2x0123 = _mm256_xor_si256(va4x2x0123, vsign_mask); - va4x3x0123 = _mm256_xor_si256(va4x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -210,12 +187,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u4_acc4( __m256i va4x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = 
_mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c index 7a79c836bc2b..c81d0ec1ba56 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -100,20 +98,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 
8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -149,15 +147,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni.c index 204a72bf9f1d..dfe130ce5064 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avx256vnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // 
XNN_FORCE_REALIZATION(voutput_min); @@ -99,20 +97,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -146,15 +144,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git 
a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c index 9291aad7adef..4a57275499ec 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni-prfm.c @@ -68,13 +68,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -100,20 +98,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -149,15 +147,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni.c index 1e22d9ec1e06..02bf5b0031f6 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x8c8-minmax-avxvnni.c @@ -67,13 +67,11 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -99,20 +97,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), 
vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -146,15 +144,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..651c87fedbca --- /dev/null +++ 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,259 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 848 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 936] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with k_sum * input zero point. 
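As in the intrinsics kernels above, the new assembly kernel feeds the activation bytes straight into vpdpbusd (which multiplies unsigned bytes by signed bytes), with no 0x80 sign flip and no "+ 128" adjustment of the zero point; presumably the activations handed to these kernels are already encoded in unsigned form upstream. The block that follows loads the packed ksum row (the first 64 bytes at r9) and seeds each row's int32 accumulators with ksum * input_zero_point via vpmulld. A minimal C sketch of the per-column accumulator math, with hypothetical names; the sign convention of the stored ksum and the exact packed-weight layout come from the packing code and are not asserted here:

  #include <stddef.h>
  #include <stdint.h>

  // One output column of a qd8 VNNI GEMM (sketch only).
  static int32_t gemm_column(const uint8_t* a, const int8_t* w, int32_t ksum,
                             int32_t input_zero_point, size_t kc) {
    int32_t acc = ksum * input_zero_point;       // vpmulld: seed the accumulator
    for (size_t k = 0; k < kc; k++) {
      acc += (int32_t) a[k] * (int32_t) w[k];    // vpdpbusd: u8 x s8 dot product
    }
    return acc;                                  // later converted and scaled to float
  }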
+ vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + # Load quantization_params pointer from stack + mov r11, [rsp + 936] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + +return: + add rsp, 848 + + # Restore the callee saved registers. 
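The tail: path above handles nc < 16 by building a write mask from the remaining column count (kept in cl thanks to the rsi/rcx swap in the prologue): "-1 << cl" sets the bits to discard and the following not leaves exactly the low nc bits set, which then gate the masked zmm stores. A C intrinsics sketch of the same mask construction, assuming nc is between 1 and 15 on this path; the helper name is made up:

  #include <immintrin.h>
  #include <stdint.h>

  // Store only the first nc floats of a 16-float row (sketch).
  static inline void store_tail_x16(float* c, __m512 row, size_t nc) {
    const __mmask16 mask = (__mmask16) ~(UINT32_MAX << nc);  // not(-1 << cl)
    _mm512_mask_storeu_ps(c, mask, row);                     // vmovups [c]{k1}, row
  }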
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..96605c933cd1 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,319 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 848 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 936] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 784] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm18, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + # Load quantization_params pointer from stack + mov r11, [rsp + 936] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 4]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 12]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 20]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 28]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 36]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 44]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm11, zmm7 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + # Min/max clamping.. 
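The epilogue above follows the usual qd8 dequantization order: convert the int32 accumulators to float, multiply each row by its input scale (read with an embedded broadcast, DWORD PTR [r11 + 4]{1to16}, [r11 + 12]{1to16}, ..., so the quantization-params entries appear to be 8-byte {zero_point, scale} pairs), then fold in the per-channel weight scale and bias rows that trail the packed weights with a single FMA each, before the clamps below. A C intrinsics sketch of one row; the struct layout and field names are assumptions inferred from the byte offsets:

  #include <immintrin.h>
  #include <stdint.h>

  struct quant_params {   // assumed: matches the 8-byte stride and +4 float offset
    int32_t zero_point;
    float scale;
  };

  static inline __m512 dequant_row(__m512i acc, const struct quant_params* qp,
                                   __m512 wscale, __m512 bias,
                                   __m512 vmin, __m512 vmax) {
    __m512 out = _mm512_cvtepi32_ps(acc);                     // vcvtdq2ps
    out = _mm512_mul_ps(out, _mm512_set1_ps(qp->scale));      // vmulps ... {1to16}
    out = _mm512_fmadd_ps(out, wscale, bias);                 // vfmadd213ps
    out = _mm512_min_ps(out, vmax);                           // vminps
    return _mm512_max_ps(out, vmin);                          // vmaxps
  }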
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm18 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm19 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm20 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm21 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm22 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm23 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm23 + +return: + add rsp, 848 + + # Restore the callee saved registers. 
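For the 32-column variant the tail above splits the same bit pattern across two k-registers: the low 16 bits gate the first zmm store and, after shr r11d, 16, the remaining bits gate the store at +64 bytes, so any nc < 32 remainder is written with one pair of masked stores per row. A short sketch under the same assumptions as the 16-wide case:

  #include <immintrin.h>
  #include <stdint.h>

  // Store only the first nc floats of a 32-float row (sketch, nc < 32).
  static inline void store_tail_x32(float* c, __m512 lo, __m512 hi, size_t nc) {
    const uint32_t bits = ~(UINT32_MAX << nc);                   // not(-1 << cl)
    _mm512_mask_storeu_ps(c,      (__mmask16) bits,        lo);  // {k1}
    _mm512_mask_storeu_ps(c + 16, (__mmask16)(bits >> 16), hi);  // {k2}
  }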
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u2-acc2.c index d3052fae1cb7..3ee5553febcf 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u2-acc2.c @@ -76,14 +76,12 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u2_acc2( c5 = c4; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -125,19 +123,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u2_acc2( __m256i va5x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va5x0x0123 = _mm256_xor_si256(va5x0x0123, vsign_mask); - va5x1x0123 = _mm256_xor_si256(va5x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -177,13 +162,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u2_acc2( __m256i va5x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = 
_mm256_xor_si256(va4x0123, vsign_mask); - va5x0123 = _mm256_xor_si256(va5x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u4-acc4.c index 8dbfcf276ff3..0bd92c73f1dc 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c4-minmax-avxvnni-u4-acc4.c @@ -76,14 +76,12 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u4_acc4( c5 = c4; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -149,31 +147,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u4_acc4( __m256i va5x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a5 + 12)); a5 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va2x2x0123 = _mm256_xor_si256(va2x2x0123, vsign_mask); - va2x3x0123 = _mm256_xor_si256(va2x3x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va3x2x0123 = _mm256_xor_si256(va3x2x0123, vsign_mask); - va3x3x0123 = _mm256_xor_si256(va3x3x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va4x2x0123 = _mm256_xor_si256(va4x2x0123, vsign_mask); - va4x3x0123 = _mm256_xor_si256(va4x3x0123, vsign_mask); - va5x0x0123 = _mm256_xor_si256(va5x0x0123, vsign_mask); - va5x1x0123 = 
_mm256_xor_si256(va5x1x0123, vsign_mask); - va5x2x0123 = _mm256_xor_si256(va5x2x0123, vsign_mask); - va5x3x0123 = _mm256_xor_si256(va5x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -239,13 +212,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u4_acc4( __m256i va5x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - va5x0123 = _mm256_xor_si256(va5x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c index 82bcd136eb6f..b53dc2c5b47e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni-prfm.c @@ -74,14 +74,12 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -110,23 +108,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -166,17 +164,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git 
a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni.c index c9064cba95a6..657276bd63e6 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x8c8-minmax-avxvnni.c @@ -73,14 +73,12 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -109,23 +107,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -163,17 +161,17 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..6ce0f8014c4a --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,288 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 912 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
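The blocks that follow implement the standard row clamping for a fixed-MR kernel: each row's A and C pointers are derived from the previous row by adding a_stride / cm_stride, and when mr <= i the cmovle folds them back onto the previous row, so a kernel built for 7 rows can be called with fewer rows without touching memory outside the valid ones (the surplus rows simply recompute the last valid row). A C sketch of the same idea with hypothetical names:

  #include <stddef.h>
  #include <stdint.h>

  // Row-pointer clamping as done by the add/cmp/cmovle blocks (sketch).
  static void clamp_row_pointers(const int8_t* a[7], float* c[7],
                                 size_t a_stride, size_t cm_stride, size_t mr) {
    for (size_t i = 1; i < 7; i++) {
      a[i] = (const int8_t*) ((uintptr_t) a[i - 1] + a_stride);  // add rax, r8
      c[i] = (float*) ((uintptr_t) c[i - 1] + cm_stride);        // add r13, r11
      if (mr <= i) {                                             // cmp rdi, i; cmovle
        a[i] = a[i - 1];
        c[i] = c[i - 1];
      }
    }
  }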
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1000] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + # Load quantization_params pointer from stack + mov r11, [rsp + 1000] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + +return: + add rsp, 912 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512amx.c index db1c7ba5c4a0..90cfdb8e94b4 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
#include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,7 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; + __attribute__((aligned(64))) int32_t res[1][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -62,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -144,9 +149,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
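The MemorySanitizer guard added above addresses a real instrumentation gap: MSan does not model AMX tile instructions, so the bytes that _tile_stored writes into the res buffer would still be considered uninitialized when the vector code reads them back, producing false positives. Calling __msan_unpoison on the buffer right after the tile store marks it as defined, and the __has_feature guard keeps the call out of non-MSan builds. The pattern, reduced to a standalone sketch (the buffer name and the macro name are placeholders, not XNNPACK symbols):

  #if defined(__has_feature)
  #  if __has_feature(memory_sanitizer)
  #    include <sanitizer/msan_interface.h>
  #    define UNPOISON_FOR_MSAN(ptr, size) __msan_unpoison((ptr), (size))
  #  endif
  #endif
  #ifndef UNPOISON_FOR_MSAN
  #  define UNPOISON_FOR_MSAN(ptr, size) ((void) 0)
  #endif

  // After an instruction MSan cannot instrument has filled `buf`:
  //   _tile_stored(0, buf, 64);
  //   UNPOISON_FOR_MSAN(buf, sizeof(buf));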
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[2].zero_point)); @@ -154,13 +169,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512amx( __m512i vacc4x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[4].zero_point)); __m512i vacc5x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[5].zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c index ad1a28c41dba..50ce7a19ab87 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) 
quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -114,26 +112,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -168,19 +166,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm( vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c index 6097cc75433e..9bf995132673 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c4-minmax-avx512vnni.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + 
const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -113,26 +111,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -165,19 +163,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni( vacc6x0123456789ABCDEF = 
_mm512_add_epi32(vacc6x0123456789ABCDEF, vacc1x6x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni-prfm.c index c500d48b86c1..610a7a163ad6 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni-prfm.c @@ -81,15 +81,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = 
_mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -121,26 +119,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -186,19 +184,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni.c index 777a6e537405..dbb94c467634 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16c8-minmax-avx512vnni.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -120,26 +118,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -181,19 +179,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + 
const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..1f5c69e519c5 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,357 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 912 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1000] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + + # Initialize accumulators with k_sum * input zero point. 
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 848] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm19, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + # Load quantization_params pointer from stack + mov r11, [rsp + 1000] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 4]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 12]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 20]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 28]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 36]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 44]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 52]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm11, zmm7 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + # Min/max clamping.. 
+ vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm19 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm20 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm21 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm22 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm23 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm24 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm25 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm25 + +return: + add rsp, 912 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32c4-minmax-avx512amx.c index 0af1789cb113..48c5720f633e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
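For the 32-column assembly kernel above, the tail: path builds one 32-bit mask with the low nc bits set and splits it across two opmask registers, k1 for the first ZMM store of each row and k2 for the second; the 16-column kernel earlier needs only k1. A standalone sketch of that mask arithmetic, with an illustrative helper name:

    #include <stdint.h>

    // Tail masks for 0 < n < 32 remaining columns: ~(-1 << n) has the low n
    // bits set; kmovw k1 takes the low 16 bits, shr/kmovw k2 the high 16.
    static void tail_masks(unsigned n, uint16_t* k1, uint16_t* k2) {
      const uint32_t mask = ~(UINT32_C(0xFFFFFFFF) << n);  // mov -1; sal cl; not
      *k1 = (uint16_t) mask;                               // kmovw k1, r11d
      *k2 = (uint16_t) (mask >> 16);                       // shr r11d, 16; kmovw k2
    }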
#include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,8 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; + __attribute__((aligned(64))) int32_t res[2][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -63,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -151,10 +155,20 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
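MemorySanitizer does not instrument _tile_stored, so without the explicit unpoison the _mm512_load_epi32 reads of res below would be reported as uses of uninitialized memory; merging the separate res0/res1 buffers into one array also lets a single __msan_unpoison(res, sizeof(res)) cover every tile. The guard pattern is shown standalone here; the UNPOISON macro name is illustrative, not an XNNPACK API:

    // __msan_unpoison() is declared in <sanitizer/msan_interface.h> and is only
    // available when building under MemorySanitizer, hence the feature check.
    #if defined(__has_feature)
      #if __has_feature(memory_sanitizer)
        #include <sanitizer/msan_interface.h>
        #define UNPOISON(ptr, size) __msan_unpoison((ptr), (size))
      #endif
    #endif
    #ifndef UNPOISON
      #define UNPOISON(ptr, size) ((void) 0)  // no-op outside MSan builds
    #endif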
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[1].zero_point)); @@ -169,20 +183,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__avx512amx( __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[5].zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] 
+ 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x64c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x64c4-minmax-avx512amx.c index a1a387da0fe2..84cf3920cc44 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,10 +48,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +67,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -165,12 +167,22 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
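All of these AMX kernels round kc up to a multiple of 4 (the c4 packing width), consume 64 bytes of K per tile multiply, and size the remainder tiles from whatever is left: tmm6 holds kremainder input bytes per row and tmm7 holds kremainder >> 2 weight rows. A sketch of that bookkeeping, with an illustrative helper name:

    #include <stddef.h>

    // kc is rounded up to a multiple of 4 bytes; the main loop processes K in
    // 64-byte tiles and the last iteration handles kremainder bytes, which is
    // 64 when kc is already a multiple of 64.
    static size_t k_remainder(size_t kc) {
      kc = (kc + 3) & ~(size_t) 3;           // round_up_po2(kc, 4 * sizeof(int8_t))
      return (kc & 63) ? (kc & 63) : 64;
    }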
__m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[0].zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[0].zero_point)); @@ -199,34 +211,35 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx( __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params[6].zero_point)); __m512i vacc6xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params[6].zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + // Add tile to bias + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u2-acc2.c index 6e33dadc2df2..4adcd4cdecfc 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u2-acc2.c +++ 
b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u2-acc2.c @@ -82,15 +82,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u2_acc2( c6 = c5; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -137,21 +135,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u2_acc2( __m256i va6x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va5x0x0123 = _mm256_xor_si256(va5x0x0123, vsign_mask); - va5x1x0123 = _mm256_xor_si256(va5x1x0123, vsign_mask); - va6x0x0123 = _mm256_xor_si256(va6x0x0123, vsign_mask); - va6x1x0123 = _mm256_xor_si256(va6x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -196,14 +179,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u2_acc2( __m256i va6x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - va5x0123 = _mm256_xor_si256(va5x0123, vsign_mask); - va6x0123 = _mm256_xor_si256(va6x0123, vsign_mask); - const 
__m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u4-acc4.c index 4d230429d3c0..430bab29503a 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c4-minmax-avxvnni-u4-acc4.c @@ -82,15 +82,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u4_acc4( c6 = c5; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -165,35 +163,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u4_acc4( __m256i va6x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a6 + 12)); a6 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va2x2x0123 = _mm256_xor_si256(va2x2x0123, vsign_mask); - va2x3x0123 = _mm256_xor_si256(va2x3x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va3x2x0123 = _mm256_xor_si256(va3x2x0123, vsign_mask); - va3x3x0123 = _mm256_xor_si256(va3x3x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va4x2x0123 = _mm256_xor_si256(va4x2x0123, vsign_mask); - va4x3x0123 = _mm256_xor_si256(va4x3x0123, 
vsign_mask); - va5x0x0123 = _mm256_xor_si256(va5x0x0123, vsign_mask); - va5x1x0123 = _mm256_xor_si256(va5x1x0123, vsign_mask); - va5x2x0123 = _mm256_xor_si256(va5x2x0123, vsign_mask); - va5x3x0123 = _mm256_xor_si256(va5x3x0123, vsign_mask); - va6x0x0123 = _mm256_xor_si256(va6x0x0123, vsign_mask); - va6x1x0123 = _mm256_xor_si256(va6x1x0123, vsign_mask); - va6x2x0123 = _mm256_xor_si256(va6x2x0123, vsign_mask); - va6x3x0123 = _mm256_xor_si256(va6x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -268,14 +237,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u4_acc4( __m256i va6x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - va5x0123 = _mm256_xor_si256(va5x0123, vsign_mask); - va6x0123 = _mm256_xor_si256(va6x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c index f193a0cdf06b..63bf1c0a4e20 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -120,26 +118,26 @@ 
void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -183,19 +181,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni.c index cc104dddae61..b9e304931311 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avx256vnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -119,26 +117,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -180,19 +178,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 
+= 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c index 8b6153867b33..cf78a66aff61 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni-prfm.c @@ -80,15 +80,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -120,26 +118,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - 
const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -183,19 +181,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni.c index 6d39450326e1..e848fd7b51a1 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x8c8-minmax-avxvnni.c @@ -79,15 +79,13 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -119,26 +117,26 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -180,19 +178,19 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + 
const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..59ad3e9cbacb --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,317 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 976 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1064] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + 
vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with k_sum * input zero point. + vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + # Load quantization_params pointer from stack + mov r11, [rsp + 1064] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. 
+ cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + +return: + add rsp, 976 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c index 13cebfbd9da7..478351ffd481 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -123,29 
+121,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -183,21 +181,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm( vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i 
va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c index b88bef9eff6f..83cf51369fae 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c4-minmax-avx512vnni.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = 
_mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -122,29 +120,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -180,21 +178,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni( 
vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, vacc1x7x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni-prfm.c index dca537660712..73d64e9be81d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni-prfm.c @@ -87,16 +87,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) 
quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,29 +129,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i 
va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -203,21 +201,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni.c index ea2c1ce9660f..7fc7920d2241 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16c8-minmax-avx512vnni.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i 
vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -130,29 +128,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + 
const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -198,21 +196,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..ce9a3bc9cef7 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,395 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. 
+ mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 976 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Load quantization params pointer from stack + mov r11, [rsp + 1064] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + + # Initialize accumulators with k_sum * input zero point. 
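As a reading aid for the new 8x32 assembly kernel, the accumulator initialization and the vpdpbusd inner loop amount to the following scalar C sketch. The names (packed_ksum, a, b, zero_point, kc) are illustrative only and do not reflect the real packed-weight layout; the points it shows are that the accumulator starts at k_sum times the row's input zero point and that vpdpbusd multiplies unsigned activation bytes by signed weight bytes.

#include <stddef.h>
#include <stdint.h>

// Rough model of one accumulator lane (one output channel, one row).
int32_t qd8_acc_sketch(const uint8_t* a, const int8_t* b, size_t kc,
                       int32_t zero_point, int32_t packed_ksum) {
  // "Initialize accumulators with k_sum * input zero point."
  int32_t acc = packed_ksum * zero_point;
  for (size_t k = 0; k < kc; k++) {
    // vpdpbusd: zero-extended activation byte times sign-extended weight byte.
    acc += (int32_t) a[k] * (int32_t) b[k];
  }
  // Assuming the packing step stores packed_ksum as the per-channel correction
  // (commonly the negated column sum of b), the zero-point term cancels and
  // acc equals sum_k (a[k] - zero_point) * b[k].
  return acc;
}

After the inner loop the kernel converts these int32 accumulators to float and applies the per-row input scale, the per-channel weight scale and bias, and the min/max clamp, as shown in the epilogue that follows.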
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 912] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm20, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + # Load quantization_params pointer from stack + mov r11, [rsp + 1064] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 4]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 12]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 20]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 28]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 36]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 44]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 52]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 60]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm11, zmm7 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + 
vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + + # Check whether full or partial store. + cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm20 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm21 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm22 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm23 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm24 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm25 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm26 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm27 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm20 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm27 + +return: + add rsp, 976 + + # Restore the callee saved registers. 
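For the partial-store path of the same kernel (the tail: label above), the sal/not/kmovw sequence builds a 32-bit column mask and splits it into two 16-lane AVX-512 write masks, one per 16-float ZMM store. A small C equivalent, a hypothetical helper that is only valid for nc < 32 (the branch is taken only when fewer than 32 columns remain), looks like this:

#include <stdint.h>

// mask has the low nc bits set: not(-1 << nc) in the assembly.
static inline void make_tail_masks(uint32_t nc, uint16_t* k1, uint16_t* k2) {
  const uint32_t mask = ~(UINT32_C(0xFFFFFFFF) << nc);
  *k1 = (uint16_t) mask;          // write mask for columns 0..15
  *k2 = (uint16_t) (mask >> 16);  // write mask for columns 16..31
}

The full-store path writes all 32 columns unconditionally; only the final partial iteration goes through these masked stores.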
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u2-acc2.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u2-acc2.c index 3ecfd5933911..a9fc77c8393f 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u2-acc2.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u2-acc2.c @@ -88,16 +88,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u2_acc2( c7 = c6; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -149,23 +147,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u2_acc2( __m256i va7x1x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = _mm256_xor_si256(va0x1x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va5x0x0123 = _mm256_xor_si256(va5x0x0123, vsign_mask); - va5x1x0123 = _mm256_xor_si256(va5x1x0123, vsign_mask); - va6x0x0123 = _mm256_xor_si256(va6x0x0123, vsign_mask); - va6x1x0123 = _mm256_xor_si256(va6x1x0123, vsign_mask); - va7x0x0123 = _mm256_xor_si256(va7x0x0123, vsign_mask); - va7x1x0123 = 
_mm256_xor_si256(va7x1x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); w = (const int8_t*) w + 64; @@ -215,15 +196,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u2_acc2( __m256i va7x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - va5x0123 = _mm256_xor_si256(va5x0123, vsign_mask); - va6x0123 = _mm256_xor_si256(va6x0123, vsign_mask); - va7x0123 = _mm256_xor_si256(va7x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u4-acc4.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u4-acc4.c index 2e9bc9b5f412..4246b5f01584 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u4-acc4.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c4-minmax-avxvnni-u4-acc4.c @@ -88,16 +88,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u4_acc4( c7 = c6; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -181,39 +179,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u4_acc4( __m256i va7x3x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a7 + 12)); a7 += 16; - va0x0x0123 = _mm256_xor_si256(va0x0x0123, vsign_mask); - va0x1x0123 = 
_mm256_xor_si256(va0x1x0123, vsign_mask); - va0x2x0123 = _mm256_xor_si256(va0x2x0123, vsign_mask); - va0x3x0123 = _mm256_xor_si256(va0x3x0123, vsign_mask); - va1x0x0123 = _mm256_xor_si256(va1x0x0123, vsign_mask); - va1x1x0123 = _mm256_xor_si256(va1x1x0123, vsign_mask); - va1x2x0123 = _mm256_xor_si256(va1x2x0123, vsign_mask); - va1x3x0123 = _mm256_xor_si256(va1x3x0123, vsign_mask); - va2x0x0123 = _mm256_xor_si256(va2x0x0123, vsign_mask); - va2x1x0123 = _mm256_xor_si256(va2x1x0123, vsign_mask); - va2x2x0123 = _mm256_xor_si256(va2x2x0123, vsign_mask); - va2x3x0123 = _mm256_xor_si256(va2x3x0123, vsign_mask); - va3x0x0123 = _mm256_xor_si256(va3x0x0123, vsign_mask); - va3x1x0123 = _mm256_xor_si256(va3x1x0123, vsign_mask); - va3x2x0123 = _mm256_xor_si256(va3x2x0123, vsign_mask); - va3x3x0123 = _mm256_xor_si256(va3x3x0123, vsign_mask); - va4x0x0123 = _mm256_xor_si256(va4x0x0123, vsign_mask); - va4x1x0123 = _mm256_xor_si256(va4x1x0123, vsign_mask); - va4x2x0123 = _mm256_xor_si256(va4x2x0123, vsign_mask); - va4x3x0123 = _mm256_xor_si256(va4x3x0123, vsign_mask); - va5x0x0123 = _mm256_xor_si256(va5x0x0123, vsign_mask); - va5x1x0123 = _mm256_xor_si256(va5x1x0123, vsign_mask); - va5x2x0123 = _mm256_xor_si256(va5x2x0123, vsign_mask); - va5x3x0123 = _mm256_xor_si256(va5x3x0123, vsign_mask); - va6x0x0123 = _mm256_xor_si256(va6x0x0123, vsign_mask); - va6x1x0123 = _mm256_xor_si256(va6x1x0123, vsign_mask); - va6x2x0123 = _mm256_xor_si256(va6x2x0123, vsign_mask); - va6x3x0123 = _mm256_xor_si256(va6x3x0123, vsign_mask); - va7x0x0123 = _mm256_xor_si256(va7x0x0123, vsign_mask); - va7x1x0123 = _mm256_xor_si256(va7x1x0123, vsign_mask); - va7x2x0123 = _mm256_xor_si256(va7x2x0123, vsign_mask); - va7x3x0123 = _mm256_xor_si256(va7x3x0123, vsign_mask); - const __m256i vb0x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 0)); const __m256i vb1x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)); const __m256i vb2x01234567 = _mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)); @@ -297,15 +262,6 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u4_acc4( __m256i va7x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - va0x0123 = _mm256_xor_si256(va0x0123, vsign_mask); - va1x0123 = _mm256_xor_si256(va1x0123, vsign_mask); - va2x0123 = _mm256_xor_si256(va2x0123, vsign_mask); - va3x0123 = _mm256_xor_si256(va3x0123, vsign_mask); - va4x0123 = _mm256_xor_si256(va4x0123, vsign_mask); - va5x0123 = _mm256_xor_si256(va5x0123, vsign_mask); - va6x0123 = _mm256_xor_si256(va6x0123, vsign_mask); - va7x0123 = _mm256_xor_si256(va7x0123, vsign_mask); - const __m256i vb01234567 = _mm256_load_si256(w); vacc0x0x01234567 = _mm256_dpbusd_avx_epi32(vacc0x0x01234567, va0x0123, vb01234567); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c index 4db3b6645b70..0e39ec001a27 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = 
_mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -130,29 +128,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), 
vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -200,21 +198,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni.c index a280ce0986d2..2db77b7da59c 100644 --- 
a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avx256vnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -129,29 +127,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -197,21 +195,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c index 465ea3b95200..0f84d37f1619 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni-prfm.c @@ -86,16 +86,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -130,29 +128,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -200,21 +198,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i 
va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni.c index 21460749556f..d6732b20d96c 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x8c8-minmax-avxvnni.c @@ -85,16 +85,14 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -129,29 +127,29 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -197,21 +195,21 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 
8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..c778316dfdb2 --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,346 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1040 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
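+ # Stack slots [rsp - 128] through [rsp - 264] hold the per-row a and c pointers (a0/c0 through a8/c8) that the clamping blocks below spill, freeing the GP registers for the inner loop.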
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1128] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with k_sum * input zero point. 
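+ # The first 64-byte row of packed weights (loaded into zmm6 below) is multiplied by each row's broadcast
+ # input zero point spilled above at [rsp + 464] .. [rsp + 976], folding the zero-point correction into the
+ # int32 accumulators before the vpdpbusd loop.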
+ vmovaps zmm6, [r9 + 0] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + add r9, 64 + +inner_loop: + vmovaps zmm6, [r9 + 0] + add r9, 64 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. + vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + # Load quantization_params pointer from stack + mov r11, [rsp + 1128] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmovaps zmm10, [r9 + 0] + add r9, 64 + vmovaps zmm6, [r9 + 0] + add r9, 64 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. + cmp rcx, 16 + jl tail + + vmovups [rsi], zmm5 + vmovups [rax], zmm12 + vmovups [r15], zmm14 + vmovups [r14], zmm15 + vmovups [r12], zmm16 + vmovups [r10], zmm17 + vmovups [r13], zmm18 + vmovups [rbx], zmm19 + vmovups [rbp], zmm20 + add rsi, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + # Write output pointers to the stack. 
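+ # Spill the advanced c pointers for the next outer_loop iteration; rcx counts the remaining output columns (nc).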
+ mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 16 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + +return: + add rsp, 1040 + + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni-prfm.c index 4c1774b7714f..ede15c2f6123 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -132,32 +130,32 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 
* sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -198,23 +196,23 @@ void 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni.c index 7a11a39275e9..4d4a33590d6d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c4-minmax-avx512vnni.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = 
_mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -131,32 +129,32 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i 
va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -195,23 +193,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni( vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, vacc1x8x0123456789ABCDEF); if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni-prfm.c index 43b829c9e76d..b9def5ff9ed3 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni-prfm.c @@ -93,17 +93,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = 
_mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -141,32 +139,32 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const 
__m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -220,23 +218,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), 
vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni.c index 840a5a155887..f27115b437d2 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point); + const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point); + const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point); + const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point); + const __m512i vinput_zero_point4 = _mm512_set1_epi32((int) quantization_params[4].zero_point); + const __m512i vinput_zero_point5 = _mm512_set1_epi32((int) quantization_params[5].zero_point); + const __m512i vinput_zero_point6 = _mm512_set1_epi32((int) quantization_params[6].zero_point); + const __m512i vinput_zero_point7 = _mm512_set1_epi32((int) quantization_params[7].zero_point); + const __m512i vinput_zero_point8 = _mm512_set1_epi32((int) quantization_params[8].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -140,32 +138,32 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i 
va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -215,23 +213,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S new file mode 100644 index 000000000000..642dcbecfacc --- /dev/null +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S @@ -0,0 +1,433 @@ +#include "xnnpack/assembly.h" + +BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # Swap rsi & rcx because sal can only use cl. + mov r15, rsi + mov rsi, rcx + mov rcx, r15 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + add rdx, 3 + and rdx, -4 + sub rsp, 1040 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp - 128], rsi + # Write r10 (c pointer) to the stack as we need the register. 
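+ # As in the 16-wide kernel, slots [rsp - 128] .. [rsp - 264] hold the per-row a and c pointers set up by the clamping blocks below.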
+ mov [rsp - 136], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 144], rax + mov [rsp - 152], r13 + + # Clamp a & c pointers if mr <= 2 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 160], rsi + mov [rsp - 168], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 176], rax + mov [rsp - 184], r13 + + # Clamp a & c pointers if mr <= 4 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 192], rsi + mov [rsp - 200], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 208], rax + mov [rsp - 216], r13 + + # Clamp a & c pointers if mr <= 6 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 224], rsi + mov [rsp - 232], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rsi + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rsi + cmovle r13, r10 + + mov [rsp - 240], rax + mov [rsp - 248], r13 + + # Clamp a & c pointers if mr <= 8 + mov rsi, rax + add rsi, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rsi, rax + cmovle r10, r13 + + mov [rsp - 256], rsi + mov [rsp - 264], r10 + + # Load quantization params pointer from stack + mov r11, [rsp + 1128] + mov edi, [r11 + 0] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 464], zmm6 + mov edi, [r11 + 8] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 528], zmm6 + mov edi, [r11 + 16] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 592], zmm6 + mov edi, [r11 + 24] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 656], zmm6 + mov edi, [r11 + 32] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 720], zmm6 + mov edi, [r11 + 40] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 784], zmm6 + mov edi, [r11 + 48] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 848], zmm6 + mov edi, [r11 + 56] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 912], zmm6 + mov edi, [r11 + 64] + vpbroadcastd zmm6, edi + vmovups zmmword ptr [rsp + 976], zmm6 + +outer_loop: + # Zero k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rsi, [rsp - 128] + mov rax, [rsp - 144] + mov r15, [rsp - 160] + mov r14, [rsp - 176] + mov r12, [rsp - 192] + mov r10, [rsp - 208] + mov r13, [rsp - 224] + mov rbx, [rsp - 240] + mov rbp, [rsp - 256] + + # Initialize accumulators with k_sum * input zero point. 
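+ # Two 64-byte rows of packed weights (zmm6/zmm7 below) cover the 32 output columns; each is multiplied by the
+ # broadcast per-row input zero point spilled above, folding the zero-point correction into the accumulators.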
+ vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + vpmulld zmm5, zmm6, ZMMWORD PTR [rsp + 464] + vpmulld zmm12, zmm6, ZMMWORD PTR [rsp + 528] + vpmulld zmm14, zmm6, ZMMWORD PTR [rsp + 592] + vpmulld zmm15, zmm6, ZMMWORD PTR [rsp + 656] + vpmulld zmm16, zmm6, ZMMWORD PTR [rsp + 720] + vpmulld zmm17, zmm6, ZMMWORD PTR [rsp + 784] + vpmulld zmm18, zmm6, ZMMWORD PTR [rsp + 848] + vpmulld zmm19, zmm6, ZMMWORD PTR [rsp + 912] + vpmulld zmm20, zmm6, ZMMWORD PTR [rsp + 976] + vpmulld zmm21, zmm7, ZMMWORD PTR [rsp + 464] + vpmulld zmm22, zmm7, ZMMWORD PTR [rsp + 528] + vpmulld zmm23, zmm7, ZMMWORD PTR [rsp + 592] + vpmulld zmm24, zmm7, ZMMWORD PTR [rsp + 656] + vpmulld zmm25, zmm7, ZMMWORD PTR [rsp + 720] + vpmulld zmm26, zmm7, ZMMWORD PTR [rsp + 784] + vpmulld zmm27, zmm7, ZMMWORD PTR [rsp + 848] + vpmulld zmm28, zmm7, ZMMWORD PTR [rsp + 912] + vpmulld zmm29, zmm7, ZMMWORD PTR [rsp + 976] + add r9, 128 + +inner_loop: + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vpbroadcastd zmm2, [rsi + r11] + vpdpbusd zmm5, zmm2, zmm6 + vpdpbusd zmm21, zmm2, zmm7 + vpbroadcastd zmm2, [rax + r11] + vpdpbusd zmm12, zmm2, zmm6 + vpdpbusd zmm22, zmm2, zmm7 + vpbroadcastd zmm2, [r15 + r11] + vpdpbusd zmm14, zmm2, zmm6 + vpdpbusd zmm23, zmm2, zmm7 + vpbroadcastd zmm2, [r14 + r11] + vpdpbusd zmm15, zmm2, zmm6 + vpdpbusd zmm24, zmm2, zmm7 + vpbroadcastd zmm2, [r12 + r11] + vpdpbusd zmm16, zmm2, zmm6 + vpdpbusd zmm25, zmm2, zmm7 + vpbroadcastd zmm2, [r10 + r11] + vpdpbusd zmm17, zmm2, zmm6 + vpdpbusd zmm26, zmm2, zmm7 + vpbroadcastd zmm2, [r13 + r11] + vpdpbusd zmm18, zmm2, zmm6 + vpdpbusd zmm27, zmm2, zmm7 + vpbroadcastd zmm2, [rbx + r11] + vpdpbusd zmm19, zmm2, zmm6 + vpdpbusd zmm28, zmm2, zmm7 + vpbroadcastd zmm2, [rbp + r11] + vpdpbusd zmm20, zmm2, zmm6 + vpdpbusd zmm29, zmm2, zmm7 + + add r11, 4 + cmp rdx, r11 + jne inner_loop + + # Convert from int32 to float. 
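+ # Dequantize: after the int32-to-float conversion, each row is scaled by its per-row scale from
+ # quantization_params (offsets 4, 12, ... with a {1to16} broadcast), then the per-channel weight scale and
+ # bias rows of the packed weights are applied with vfmadd213ps.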
+ vcvtdq2ps zmm5, zmm5 + vcvtdq2ps zmm12, zmm12 + vcvtdq2ps zmm14, zmm14 + vcvtdq2ps zmm15, zmm15 + vcvtdq2ps zmm16, zmm16 + vcvtdq2ps zmm17, zmm17 + vcvtdq2ps zmm18, zmm18 + vcvtdq2ps zmm19, zmm19 + vcvtdq2ps zmm20, zmm20 + vcvtdq2ps zmm21, zmm21 + vcvtdq2ps zmm22, zmm22 + vcvtdq2ps zmm23, zmm23 + vcvtdq2ps zmm24, zmm24 + vcvtdq2ps zmm25, zmm25 + vcvtdq2ps zmm26, zmm26 + vcvtdq2ps zmm27, zmm27 + vcvtdq2ps zmm28, zmm28 + vcvtdq2ps zmm29, zmm29 + # Load quantization_params pointer from stack + mov r11, [rsp + 1128] + vmulps zmm5, zmm5, DWORD PTR [r11 + 4]{1to16} + vmulps zmm12, zmm12, DWORD PTR [r11 + 12]{1to16} + vmulps zmm14, zmm14, DWORD PTR [r11 + 20]{1to16} + vmulps zmm15, zmm15, DWORD PTR [r11 + 28]{1to16} + vmulps zmm16, zmm16, DWORD PTR [r11 + 36]{1to16} + vmulps zmm17, zmm17, DWORD PTR [r11 + 44]{1to16} + vmulps zmm18, zmm18, DWORD PTR [r11 + 52]{1to16} + vmulps zmm19, zmm19, DWORD PTR [r11 + 60]{1to16} + vmulps zmm20, zmm20, DWORD PTR [r11 + 68]{1to16} + vmulps zmm21, zmm21, DWORD PTR [r11 + 4]{1to16} + vmulps zmm22, zmm22, DWORD PTR [r11 + 12]{1to16} + vmulps zmm23, zmm23, DWORD PTR [r11 + 20]{1to16} + vmulps zmm24, zmm24, DWORD PTR [r11 + 28]{1to16} + vmulps zmm25, zmm25, DWORD PTR [r11 + 36]{1to16} + vmulps zmm26, zmm26, DWORD PTR [r11 + 44]{1to16} + vmulps zmm27, zmm27, DWORD PTR [r11 + 52]{1to16} + vmulps zmm28, zmm28, DWORD PTR [r11 + 60]{1to16} + vmulps zmm29, zmm29, DWORD PTR [r11 + 68]{1to16} + vmovaps zmm10, [r9 + 0] + vmovaps zmm11, [r9 + 64] + add r9, 128 + vmovaps zmm6, [r9 + 0] + vmovaps zmm7, [r9 + 64] + add r9, 128 + vfmadd213ps zmm5, zmm10, zmm6 + vfmadd213ps zmm12, zmm10, zmm6 + vfmadd213ps zmm14, zmm10, zmm6 + vfmadd213ps zmm15, zmm10, zmm6 + vfmadd213ps zmm16, zmm10, zmm6 + vfmadd213ps zmm17, zmm10, zmm6 + vfmadd213ps zmm18, zmm10, zmm6 + vfmadd213ps zmm19, zmm10, zmm6 + vfmadd213ps zmm20, zmm10, zmm6 + vfmadd213ps zmm21, zmm11, zmm7 + vfmadd213ps zmm22, zmm11, zmm7 + vfmadd213ps zmm23, zmm11, zmm7 + vfmadd213ps zmm24, zmm11, zmm7 + vfmadd213ps zmm25, zmm11, zmm7 + vfmadd213ps zmm26, zmm11, zmm7 + vfmadd213ps zmm27, zmm11, zmm7 + vfmadd213ps zmm28, zmm11, zmm7 + vfmadd213ps zmm29, zmm11, zmm7 + # Min/max clamping.. + vminps zmm5, zmm1, zmm5 + vminps zmm12, zmm1, zmm12 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vminps zmm22, zmm1, zmm22 + vminps zmm23, zmm1, zmm23 + vminps zmm24, zmm1, zmm24 + vminps zmm25, zmm1, zmm25 + vminps zmm26, zmm1, zmm26 + vminps zmm27, zmm1, zmm27 + vminps zmm28, zmm1, zmm28 + vminps zmm29, zmm1, zmm29 + vmaxps zmm5, zmm0, zmm5 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + vmaxps zmm22, zmm0, zmm22 + vmaxps zmm23, zmm0, zmm23 + vmaxps zmm24, zmm0, zmm24 + vmaxps zmm25, zmm0, zmm25 + vmaxps zmm26, zmm0, zmm26 + vmaxps zmm27, zmm0, zmm27 + vmaxps zmm28, zmm0, zmm28 + vmaxps zmm29, zmm0, zmm29 + + # Pop output pointers from the stack. + mov rsi, [rsp - 136] + mov rax, [rsp - 152] + mov r15, [rsp - 168] + mov r14, [rsp - 184] + mov r12, [rsp - 200] + mov r10, [rsp - 216] + mov r13, [rsp - 232] + mov rbx, [rsp - 248] + mov rbp, [rsp - 264] + + # Check whether full or partial store. 
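+ # rcx holds the remaining output columns; the tail path builds (1 << rcx) - 1 in r11d and splits it into the
+ # k1/k2 masks used for the low and high 16 columns of each masked store.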
+ cmp rcx, 32 + jl tail + + vmovups [rsi], zmm5 + vmovups [rsi + 64], zmm21 + vmovups [rax], zmm12 + vmovups [rax + 64], zmm22 + vmovups [r15], zmm14 + vmovups [r15 + 64], zmm23 + vmovups [r14], zmm15 + vmovups [r14 + 64], zmm24 + vmovups [r12], zmm16 + vmovups [r12 + 64], zmm25 + vmovups [r10], zmm17 + vmovups [r10 + 64], zmm26 + vmovups [r13], zmm18 + vmovups [r13 + 64], zmm27 + vmovups [rbx], zmm19 + vmovups [rbx + 64], zmm28 + vmovups [rbp], zmm20 + vmovups [rbp + 64], zmm29 + add rsi, 128 + add rax, 128 + add r15, 128 + add r14, 128 + add r12, 128 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + # Write output pointers to the stack. + mov [rsp - 136], rsi + mov [rsp - 152], rax + mov [rsp - 168], r15 + mov [rsp - 184], r14 + mov [rsp - 200], r12 + mov [rsp - 216], r10 + mov [rsp - 232], r13 + mov [rsp - 248], rbx + mov [rsp - 264], rbp + + sub rcx, 32 + jne outer_loop + jmp return + +tail: + mov r11d, -1 + sal r11d, cl + not r11d + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [rsi]{k1}, zmm5 + vmovups ZMMWORD PTR [rsi + 64]{k2}, zmm21 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [rax + 64]{k2}, zmm22 + vmovups ZMMWORD PTR [r15]{k1}, zmm14 + vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm23 + vmovups ZMMWORD PTR [r14]{k1}, zmm15 + vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm24 + vmovups ZMMWORD PTR [r12]{k1}, zmm16 + vmovups ZMMWORD PTR [r12 + 64]{k2}, zmm25 + vmovups ZMMWORD PTR [r10]{k1}, zmm17 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm26 + vmovups ZMMWORD PTR [r13]{k1}, zmm18 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm27 + vmovups ZMMWORD PTR [rbx]{k1}, zmm19 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm28 + vmovups ZMMWORD PTR [rbp]{k1}, zmm20 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm29 + +return: + add rsp, 1040 + + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni \ No newline at end of file diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c index e1da795ce6c1..c3cd93cbbe99 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni-prfm.c @@ -92,17 +92,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -140,32 +138,32 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -217,23 +215,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i 
va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni.c index f9350a2f9aa9..a4ed2c6a496b 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x8c8-minmax-avx256vnni.c @@ -91,17 +91,15 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point + 128); - const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point + 128); - const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point + 128); - const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point + 128); - const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point + 128); - const __m256i vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point + 128); - const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point + 128); - const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point + 128); - const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point + 128); + const __m256i vinput_zero_point0 = _mm256_set1_epi32((int) quantization_params[0].zero_point); + const __m256i vinput_zero_point1 = _mm256_set1_epi32((int) quantization_params[1].zero_point); + const __m256i vinput_zero_point2 = _mm256_set1_epi32((int) quantization_params[2].zero_point); + const __m256i vinput_zero_point3 = _mm256_set1_epi32((int) quantization_params[3].zero_point); + const __m256i vinput_zero_point4 = _mm256_set1_epi32((int) quantization_params[4].zero_point); + const __m256i 
vinput_zero_point5 = _mm256_set1_epi32((int) quantization_params[5].zero_point); + const __m256i vinput_zero_point6 = _mm256_set1_epi32((int) quantization_params[6].zero_point); + const __m256i vinput_zero_point7 = _mm256_set1_epi32((int) quantization_params[7].zero_point); + const __m256i vinput_zero_point8 = _mm256_set1_epi32((int) quantization_params[8].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -139,32 +137,32 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -214,23 +212,23 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni-prfm.c index 83a97f5dcbe5..c2295814606d 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni-prfm.c @@ -83,9 +83,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - 
XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -183,35 +181,35 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i 
va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -245,25 +243,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni.c index fb81a4056e13..310a861ae7f7 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c4-minmax-avx512vnni.c @@ -82,9 +82,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = 
_mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -182,35 +180,35 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const 
__m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -242,25 +240,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni-prfm.c index 8e4e271ae435..ed2a7634dce2 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni-prfm.c @@ -83,9 +83,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 
128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -184,35 +182,35 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -270,25 +268,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni.c index ce5ff551d77c..8cb65d159815 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni.c @@ -82,9 +82,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni( c9 = c8; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -183,35 +181,35 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -265,25 +263,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c index 549f1ef198c8..748c7017786c 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni-prfm.c @@ -83,9 +83,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -184,35 +182,35 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const 
__m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -268,25 +266,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni.c index 31fc41437e62..c27d84a4c6d9 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x8c8-minmax-avx256vnni.c @@ -82,9 +82,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni( c9 = c8; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -183,35 +181,35 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -265,25 +263,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni-prfm.c index 31626544b18c..0b9ed4337c1c 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni-prfm.c @@ -91,9 +91,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -207,41 +205,41 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -279,29 +277,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni_prfm( } if (k != 0) { - 
const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni.c index e73a3666ddd6..d305ab6e56a4 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c4-minmax-avx512vnni.c @@ -90,9 +90,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -206,41 +204,41 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const 
__m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i 
va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -276,29 +274,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni-prfm.c 
b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni-prfm.c index b5e6a5ecd99c..f785ab401ff4 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni-prfm.c @@ -91,9 +91,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni_prfm( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -208,41 +206,41 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -308,29 +306,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni.c index cd60422bded4..f21294c79bb7 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x16c8-minmax-avx512vnni.c @@ -90,9 +90,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni( c11 = c10; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -207,41 +205,41 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 
+= 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); 
+ const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -303,29 +301,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c index 9ce06d3fc930..02ede3c255f6 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni-prfm.c @@ -91,9 +91,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -208,41 +206,41 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -306,29 +304,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni.c index 105b1a6664e4..38a4aa2eb8a4 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-12x8c8-minmax-avx256vnni.c @@ -90,9 +90,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni( c11 = c10; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -207,41 +205,41 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -303,29 +301,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni-prfm.c index 694272a28674..81724c018877 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni-prfm.c @@ -99,9 +99,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -231,47 +229,47 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); 
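/*
 * Editor's note, not part of the generated kernels: every hunk in these files makes the
 * same two-part change. The old code XORed each int8 activation byte with 0x80 (vsign_mask)
 * to shift it into unsigned range and compensated by adding 128 to the input zero point;
 * the new code broadcasts the signed bytes as-is and uses the unbiased zero point. The
 * sketch below is an illustrative before/after of that load pattern only, under the
 * assumption stated in the comments; load_u32, load_activation_old and load_activation_new
 * are hypothetical helper names, with load_u32 standing in for the kernels' own
 * unaligned_load_u32. How the dot-product accumulation absorbs the signedness change is
 * outside these hunks and is not shown here.
 */
#include <immintrin.h>
#include <stdint.h>
#include <string.h>

/* Unaligned 4-byte load of int8 activations, mirroring the kernels' helper. */
static inline uint32_t load_u32(const int8_t* p) {
  uint32_t v;
  memcpy(&v, p, sizeof(v));
  return v;
}

/* Old pattern: flip the sign bit of every byte so the activations become unsigned,
 * which the kernel then pairs with a zero point biased by +128. */
static inline __m512i load_activation_old(const int8_t* a) {
  const __m512i vsign_mask = _mm512_set1_epi8((char) 0x80);
  return _mm512_xor_epi32(_mm512_set1_epi32((int) load_u32(a)), vsign_mask);
}

/* New pattern: broadcast the signed bytes directly; no sign flip and no +128 bias. */
static inline __m512i load_activation_new(const int8_t* a) {
  return _mm512_set1_epi32((int) load_u32(a));
}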
a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -313,33 +311,33 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), 
vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni.c index 09e2a67be5f7..fc81905028fc 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c4-minmax-avx512vnni.c @@ -98,9 +98,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -230,47 +228,47 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - 
const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); - const __m512i va9x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); + const __m512i va9x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a9 + 4)); a9 += 8; - const __m512i va10x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); - const __m512i 
va10x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); + const __m512i va10x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a10 + 4)); a10 += 8; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); - const __m512i va11x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); + const __m512i va11x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a11 + 4)); a11 += 8; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); - const __m512i va12x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); + const __m512i va12x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a12 + 4)); a12 += 8; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); - const __m512i va13x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); + const __m512i va13x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a13 + 4)); a13 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -310,33 +308,33 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; - const __m512i va9x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a9)), vsign_mask); + const __m512i va9x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a9)); a9 += 4; - const __m512i va10x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a10)), vsign_mask); + const __m512i va10x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a10)); a10 += 4; - const __m512i va11x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a11)), vsign_mask); + const __m512i va11x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a11)); a11 += 4; - const __m512i va12x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a12)), vsign_mask); + const __m512i va12x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a12)); a12 += 4; - const __m512i va13x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a13)), vsign_mask); + const __m512i va13x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a13)); a13 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni-prfm.c index a3cab34bc6ab..c3dfbfe4e7d9 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni-prfm.c @@ -99,9 +99,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni_prfm( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -232,47 +230,47 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -346,33 +344,33 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) 
unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni.c index 9d5ba4bfd01c..d51dabcc2523 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x16c8-minmax-avx512vnni.c @@ -98,9 +98,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni( c13 = c12; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -231,47 +229,47 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m512i va9x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); + const __m512i va9x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m512i va10x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); + const __m512i va10x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m512i va11x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); + const __m512i va11x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m512i va12x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); + const __m512i va12x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m512i va13x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m512i va13x01234567 = 
_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); + const __m512i va13x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -341,33 +339,33 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m512i va9x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m512i va9x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m512i va10x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m512i va10x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m512i va11x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m512i va11x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m512i va12x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m512i va12x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m512i va13x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m512i va13x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c index 092b0e6c547d..87ac83992ace 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni-prfm.c @@ -99,9 +99,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -232,47 +230,47 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -344,33 +342,33 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c index 97b4221feafb..590d8a6821b6 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c @@ -98,9 +98,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni( c13 = c12; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = 
_mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -231,47 +229,47 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), 
vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); - const __m256i va9x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); + const __m256i va9x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9 + 8)); a9 += 16; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); - const __m256i va10x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); + const __m256i va10x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10 + 8)); a10 += 16; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); - const __m256i va11x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); + const __m256i va11x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11 + 8)); a11 += 16; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); - const __m256i va12x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); + const __m256i va12x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12 + 8)); a12 += 16; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); - const __m256i va13x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); + const __m256i va13x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13 + 8)); a13 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -341,33 +339,33 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; - const __m256i va9x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)), vsign_mask); + const __m256i va9x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a9)); a9 += 8; - const __m256i va10x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)), vsign_mask); + const __m256i va10x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a10)); a10 += 8; - const __m256i va11x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)), vsign_mask); + const __m256i va11x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a11)); a11 += 8; - const __m256i va12x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)), vsign_mask); + const __m256i va12x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a12)); a12 += 8; - const __m256i va13x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)), vsign_mask); + const __m256i va13x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a13)); a13 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx-prfm.c index 12d0c3a6e90a..eb6694b2e464 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -47,7 +52,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -67,19 +72,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -460,9 +465,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -479,22 +495,22 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx_prfm( __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx.c index d5dbdfe9ce27..6f52b9d1c2cb 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -46,7 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -66,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -395,9 +400,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) 
consider + // quantizing 1 row at a time. + // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -414,22 +430,22 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx( __m512i vacc13x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc14x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc9x0123456789ABCDEF = 
_mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx-prfm.c index 1fd6bdd4f2c8..5126314cb69f 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -47,8 +52,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -68,19 +72,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -501,10 +505,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
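An aside on the MSan guard that this patch adds after every _tile_stored call: MemorySanitizer does not model AMX tile stores, so the accumulator scratch still looks uninitialized when the kernel reads it back with _mm512_load_epi32, and the guarded __msan_unpoison call clears that false positive. The standalone sketch below shows the shape of the workaround; it is illustrative rather than part of the patch, and the sanitizer header name is an assumption restored here because the #include targets were stripped from this copy of the diff.

#include <stdint.h>
#if defined(__has_feature)
  #if __has_feature(memory_sanitizer)
    // Declares __msan_unpoison (assumed to be the header the patch includes).
    #include <sanitizer/msan_interface.h>
  #endif
#endif

void unpoison_after_tile_store(void) {
  // Accumulator scratch; the real kernel fills it with _tile_stored().
  __attribute__((aligned(64))) int32_t res[2][16 * 16];
  // ... _tile_stored(0, &res[0][0], 64); _tile_stored(1, &res[1][0], 64); ...
  #if defined(__has_feature)
    #if __has_feature(memory_sanitizer)
      // res is a local array, so sizeof(res) covers the whole 2 * 16 * 16
      // int32_t buffer, matching the sizeof(res) used by the kernels.
      __msan_unpoison(res, sizeof(res));
    #endif
  #endif
  (void) res;
}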
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -537,38 +552,38 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx_prfm( __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc13x0123456789ABCDEF = 
_mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + 
vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx.c index b682e185a341..f24a445a8385 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,8 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -67,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -404,10 +408,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
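The other mechanical change in these hunks is that the per-tile scratch arrays res0 and res1 become one array res[2][16 * 16], and every load switches from res0 + offset or res1 + offset to &res[0][0] + offset or &res[1][0] + offset. The tiles stay contiguous and 64-byte aligned, which is what lets the single __msan_unpoison(res, sizeof(res)) cover all of them. A minimal check of that layout, with illustrative names only:

#include <assert.h>
#include <stdint.h>

int main(void) {
  // Consolidated scratch, as in the patch: one 16x16 int32 block per AMX tile.
  __attribute__((aligned(64))) int32_t res[2][16 * 16];
  // Tile 1 starts exactly 16 * 16 elements after tile 0, so &res[1][0] + 208
  // addresses row 13 of the second tile, the element the old res1 + 208 held.
  assert(&res[1][0] == &res[0][0] + 16 * 16);
  // Each tile block is 1024 bytes, a multiple of 64, so &res[1][0] keeps the
  // 64-byte alignment that _mm512_load_epi32 expects.
  assert(((uintptr_t) &res[1][0]) % 64 == 0);
  // One sizeof(res) spans both tiles, which is why a single unpoison suffices.
  assert(sizeof(res) == 2 * 16 * 16 * sizeof(int32_t));
  return 0;
}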
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -440,38 +455,38 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx( __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc13x0123456789ABCDEF = 
_mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + 
vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c index 0515999379a4..f9f219164214 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -47,10 +52,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -70,19 +72,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -583,12 +585,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if 
__has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. + // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -653,70 +666,70 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, 
_mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, 
_mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, 
_mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx.c index 9ee65e1cc12f..c88ca3f75d17 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-16x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,10 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -69,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -422,12 +424,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
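For readers scanning the long _mm512_mullo_epi32 / _mm512_add_epi32 runs in these hunks: each accumulator row is seeded with the per-column weight sums times the activation zero point and then receives the raw int32 dot products that _tile_stored wrote into res. A scalar reference for a single 16-wide row, with illustrative names, reads:

#include <stdint.h>

// acc[n] = ksum[n] * zero_point + tile_row[n], i.e. the scalar form of
// _mm512_mullo_epi32(vksum, _mm512_set1_epi32(zero_point)) followed by
// _mm512_add_epi32(vacc, _mm512_load_epi32(&res[t][0] + r * 16)).
static void init_row_accumulators(const int32_t ksum[16],
                                  int32_t input_zero_point,
                                  const int32_t tile_row[16],
                                  int32_t acc[16]) {
  for (int n = 0; n < 16; n++) {
    acc[n] = ksum[n] * input_zero_point + tile_row[n];
  }
}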
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -492,70 +505,70 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx( __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc15xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, 
_mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(res1 + 240)); - vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + vacc7x0123456789ABCDEF = _mm512_add_epi32(vacc7x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc7xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vacc7xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vacc7xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + vacc8x0123456789ABCDEF = _mm512_add_epi32(vacc8x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc8xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vacc8xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc8xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + vacc9x0123456789ABCDEF = _mm512_add_epi32(vacc9x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc9xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vacc9xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc9xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + vacc10x0123456789ABCDEF = _mm512_add_epi32(vacc10x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc10xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vacc10xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc10xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + vacc11x0123456789ABCDEF = _mm512_add_epi32(vacc11x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc11xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vacc11xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc11xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + vacc12x0123456789ABCDEF = _mm512_add_epi32(vacc12x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc12xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vacc12xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc12xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + vacc13x0123456789ABCDEF = _mm512_add_epi32(vacc13x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc13xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vacc13xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc13xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + vacc14x0123456789ABCDEF = _mm512_add_epi32(vacc14x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc14xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vacc14xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc14xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + vacc15x0123456789ABCDEF = _mm512_add_epi32(vacc15x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc15xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vacc15xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc15xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = 
_mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512amx.c index 179dee061acf..1860acc8d9dd 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,7 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; + __attribute__((aligned(64))) int32_t res[1][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -66,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -141,11 +146,22 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
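
Note: the AMX kernels now collect the per-tile scratch buffers into one 2D array (res[N][...] instead of res0..resN), so a single sizeof(res) covers everything that _tile_stored wrote, and one __msan_unpoison call can mark it initialized. MemorySanitizer does not track stores performed by the AMX tile instructions, so without this the _mm512_load_epi32 reads that follow would be reported as uses of uninitialized memory. A small, self-contained sketch of the guard pattern (buffer name and size here are illustrative):

  #include <stdint.h>
  #if defined(__has_feature)
    #if __has_feature(memory_sanitizer)
      #include <sanitizer/msan_interface.h>
    #endif
  #endif

  __attribute__((aligned(64))) static int32_t res[4][16 * 16];

  static void mark_tile_results_initialized(void) {
    // MSan cannot see the _tile_stored writes, so tell it explicitly that
    // the whole scratch buffer is now initialized. No-op in other builds.
    #if defined(__has_feature)
      #if __has_feature(memory_sanitizer)
        __msan_unpoison(res, sizeof(res));
      #endif
    #endif
    (void) res;
  }
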
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c index 780277bd6605..154c1bb889f8 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni-prfm.c @@ -47,9 +47,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni_prfm( kc = round_up_po2(kc, 4 * sizeof(int8_t)); float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -75,8 +73,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -92,7 +90,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni.c index 0498ffdc4cd4..6a6b6f06d238 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c4-minmax-avx512vnni.c @@ -46,9 +46,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni( kc = round_up_po2(kc, 4 * sizeof(int8_t)); float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -74,8 +72,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = 
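
Note: across the VNNI kernels this patch drops the vsign_mask XOR on the int8 activations together with the "+ 128" on the input zero point (and with it the XNN_FORCE_REALIZATION hint that kept that constant materialized). The old form is the usual trick for an unsigned-by-signed VNNI dot product: XOR-ing a signed byte with 0x80 is the same as adding 128 and reinterpreting it as unsigned, and the extra 128 * sum(weights) injected into every accumulator was cancelled by carrying the zero point as (zero_point + 128). These hunks only show the compensation being removed; how signedness is handled afterwards (a different dot-product instruction, or inputs already offset at packing time) is not visible in this excerpt. A scalar sketch of the identity the old code relied on (all names hypothetical):

  #include <stdint.h>

  // One 4-byte group of a VNNI-style dot product, written out in scalars.
  // Flipping the sign bit of each activation adds 128 to it; the surplus
  // 128 * (w0+w1+w2+w3) is what the "+ 128" on the zero point paid back.
  static int32_t dot4_with_sign_flip(const int8_t a[4], const int8_t w[4]) {
    int32_t acc = 0;
    for (int i = 0; i < 4; i++) {
      const uint8_t a_flipped = (uint8_t) (a[i] ^ 0x80);  // == a[i] + 128
      acc += (int32_t) a_flipped * (int32_t) w[i];
    }
    return acc;  // == dot(a, w) + 128 * (w[0] + w[1] + w[2] + w[3])
  }
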
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -89,7 +87,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c index 3dc2682c2c57..976799aefa9c 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c @@ -47,9 +47,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm( kc = round_up_po2(kc, 8 * sizeof(int8_t)); float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -78,8 +76,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -101,7 +99,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni.c index 99ae9e24df21..e23806f3fccd 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni.c @@ -46,9 +46,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni( kc = round_up_po2(kc, 8 * sizeof(int8_t)); float* c0 = c; - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 
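
Note: the "c4" kernels broadcast one 32-bit group (four int8 activations) per load and the "c8" kernels broadcast one 64-bit group (eight activations), which is why the former round kc up to a multiple of 4 bytes and advance a0 by 4 or 8, while the latter round up to 8 and advance by 8 or 16. The unaligned_load_u32/unaligned_load_u64 helpers are XNNPACK's; a memcpy-based implementation is assumed here only to show what they do:

  #include <stdint.h>
  #include <string.h>

  // Assumed implementation: read 4 or 8 bytes from a possibly unaligned
  // address without undefined behaviour; compilers lower this to one load.
  static inline uint32_t unaligned_load_u32(const void* p) {
    uint32_t v;
    memcpy(&v, p, sizeof(v));
    return v;
  }

  static inline uint64_t unaligned_load_u64(const void* p) {
    uint64_t v;
    memcpy(&v, p, sizeof(v));
    return v;
  }
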
vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -77,8 +75,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -96,7 +94,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x32c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x32c4-minmax-avx512amx.c index 60cab75f9bd4..5ad9a495fc2c 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,8 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x32c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; + __attribute__((aligned(64))) int32_t res[2][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -67,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -148,14 +152,25 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x32c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. + // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x64c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x64c4-minmax-avx512amx.c index 66d53397678c..05596625ec83 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
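
Note: kremainder is the size of the last partial 64-byte slice of the reduction dimension, (kc & 63) when kc is not a multiple of 64 and a full 64 otherwise; it shapes the remainder tiles (colsb[6] = kremainder bytes of input, rows[7] = kremainder >> 2 groups of four weight bytes). A tiny worked check, with kc values chosen only for illustration:

  #include <assert.h>
  #include <stddef.h>

  static size_t kremainder_of(size_t kc) {
    return (kc & 63) ? (kc & 63) : 64;
  }

  int main(void) {
    assert(kremainder_of(64) == 64);   // exact multiple: last slice is full
    assert(kremainder_of(100) == 36);  // 100 = 64 + 36
    assert(kremainder_of(4) == 4);     // single short slice
    return 0;
  }
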
#include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,10 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -69,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -162,20 +164,31 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
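
Note: the tile_data block programs each tmm register's shape before _tile_loadconfig: rows is the number of rows in use (mr for result and input tiles, 16 for the weight tile) and colsb is the row width in bytes (64 bytes, i.e. 16 int32 results or 64 int8 inputs). The comment changes above only correct the stale "res 1" labels on tmm2/tmm3. The kernels define struct __tile_config themselves; a sketch of the 64-byte layout they appear to assume, with field names treated as illustrative:

  #include <stdint.h>

  // Assumed layout matching the AMX tile-configuration data (palette 1):
  // each of the 8 tiles is described by a row count and a row width in bytes.
  struct __tile_config {
    uint8_t palette_id;
    uint8_t start_row;
    uint8_t reserved_0[14];
    uint16_t colsb[8];      // bytes per row for tmm0..tmm7
    uint8_t reserved_1[16];
    uint8_t rows[8];        // rows in use for tmm0..tmm7
    uint8_t reserved_2[8];
  };
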
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c index 76606f23252c..869d85b22d78 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni-prfm.c @@ -47,9 +47,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm( kc = round_up_po2(kc, 8 * sizeof(int8_t)); float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -78,8 +76,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -99,7 +97,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) 
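
Note: before the raw tile results are added, each 16-column accumulator group is seeded with vksum * zero_point, and only then converted to float and scaled. This is the standard zero-point correction for dynamically quantized activations; the sketch below assumes the packed vksum term holds the negated per-column sum of the int8 weights, which is what makes the seed-then-add order come out right (the packing convention itself is not shown in this excerpt):

  #include <stddef.h>
  #include <stdint.h>

  // Scalar sketch of the per-column correction. Assumption: "ksum" is the
  // negated column sum of the weights, so ksum * zp + raw dot product equals
  // the zero-point-adjusted dot product.
  static int32_t corrected_acc(const int8_t* a_q, const int8_t* w_col,
                               size_t kc, int32_t zp) {
    int32_t ksum = 0;
    int32_t dot = 0;
    for (size_t k = 0; k < kc; k++) {
      ksum -= w_col[k];                    // negated column sum, as packed
      dot += (int32_t) a_q[k] * w_col[k];  // what the AMX/VNNI tile computes
    }
    return ksum * zp + dot;                // == sum_k (a_q[k] - zp) * w[k]
  }
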
unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni.c index 7d548da570d9..afb62f3eb383 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avx256vnni.c @@ -46,9 +46,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni( kc = round_up_po2(kc, 8 * sizeof(int8_t)); float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -77,8 +75,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -96,7 +94,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c index e939e0785877..3e7aea72abce 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni-prfm.c @@ -47,9 +47,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm( kc = round_up_po2(kc, 8 * sizeof(int8_t)); float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -78,8 +76,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -99,7 +97,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni.c index 9e51892f75f1..ce419202c452 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x8c8-minmax-avxvnni.c @@ -46,9 +46,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni( kc = round_up_po2(kc, 8 * sizeof(int8_t)); float* c0 = c; - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -77,8 +75,8 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -96,7 +94,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c index da18d3747ec4..699b567772e3 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni-prfm.c @@ -51,9 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -92,11 +90,11 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -120,9 +118,9 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni.c index f556878eff14..48c184a94aed 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-2x8c8-minmax-avxvnni.c @@ -50,9 +50,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni( c1 = c0; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -91,11 +89,11 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -117,9 +115,9 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i 
va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c index 09951bf7f5f0..0ac489c23223 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni-prfm.c @@ -55,9 +55,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -100,14 +98,14 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -135,11 +133,11 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni.c index db3d31ebf3c8..a0e7a77e4574 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-3x8c8-minmax-avxvnni.c @@ -54,9 +54,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni( c2 = c1; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -99,14 +97,14 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -132,11 +130,11 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c index 967ae7c7aef0..e318c5cfa91c 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni-prfm.c @@ -59,9 +59,7 @@ void 
xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni_prfm( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -111,17 +109,17 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -143,13 +141,13 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni.c index 02285ca4d46f..ba796e9ad988 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni.c +++ 
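
Note: the "if (k != 0)" tail in these kernels handles at most one final 4-byte (c4) or 8-byte (c8) activation group. A single fixed-size step is enough because kc is rounded up to a multiple of 4 or 8 bytes at the top of each kernel, so after the 8- or 16-byte main loop only one such group can remain. A sketch of the assumed round_up_po2 behaviour (q must be a power of two):

  #include <stddef.h>

  // Assumed behaviour of round_up_po2(): round n up to the next multiple of
  // the power-of-two q. With kc rounded to 4 (c4) or 8 (c8) bytes, at most
  // one remainder group is left for the "if (k != 0)" tail.
  static inline size_t round_up_po2(size_t n, size_t q) {
    return (n + q - 1) & ~(q - 1);
  }
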
b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x16c4-minmax-avx512vnni.c @@ -58,9 +58,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni( c3 = c2; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -110,17 +108,17 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -140,13 +138,13 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c index 0148dcc105a7..f9fb2475160c 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni-prfm.c @@ -59,9 +59,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -112,17 +110,17 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -154,13 +152,13 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i 
vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni.c index 3f49b6cefb2f..a02e0e6637e7 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c8-minmax-avxvnni.c @@ -58,9 +58,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni( c3 = c2; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -111,17 +109,17 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -151,13 +149,13 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const 
__m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c index fe48c9bca50e..468d92f67c71 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni-prfm.c @@ -63,9 +63,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -123,20 +121,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -160,15 +158,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i 
va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni.c index 9ff4f2166e3f..c7c51a6372d8 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c4-minmax-avx512vnni.c @@ -62,9 +62,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -122,20 +120,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -157,15 +155,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c index 3453536fd8f6..f9556d9a21f5 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni-prfm.c @@ -63,9 +63,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni_prfm( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -124,20 +122,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -175,15 +173,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni.c index 0a04d99a8f27..e95af562f426 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x16c8-minmax-avx512vnni.c @@ -62,9 +62,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni( c4 = c3; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -123,20 +121,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -170,15 +168,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c index 3d4c6838ac4b..6c1e09ffc2fa 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni-prfm.c @@ -63,9 +63,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -124,20 +122,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -173,15 +171,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; 
- const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni.c index f64b1fb4b4c6..1c76b80bf6a7 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avx256vnni.c @@ -62,9 +62,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -123,20 +121,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i 
va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -170,15 +168,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c index 20de391b037f..ceab69c85e6b 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni-prfm.c @@ -63,9 +63,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -124,20 +122,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -173,15 +171,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni.c index 20d5abd9cabe..1a8b94752ab6 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-5x8c8-minmax-avxvnni.c @@ -62,9 +62,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni( c4 = c3; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -123,20 +121,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = 
_mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -170,15 +168,15 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c index 0ccc5f7791d4..949606165ce7 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni-prfm.c @@ -67,9 +67,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -136,23 +134,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -192,17 +190,17 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i 
va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni.c index d560aa2630ec..755b59a2a199 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-6x8c8-minmax-avxvnni.c @@ -66,9 +66,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni( c5 = c4; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -135,23 +133,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -189,17 +187,17 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512amx.c index 4228a12a0b82..649bb00ccb0d 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
 
 #include <assert.h>
+#if defined(__has_feature)
+  #if __has_feature(memory_sanitizer)
+    #include <sanitizer/msan_interface.h>
+  #endif
+#endif
 
 #include <immintrin.h>
 
@@ -46,7 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512amx(
   // Update if amxintrin changes
 #if defined(__x86_64__)
   __attribute__((aligned(64))) int32_t vintile[7 * 16];
-  __attribute__((aligned(64))) int32_t res0[7 * 16];
+  __attribute__((aligned(64))) int32_t res[1][7 * 16];
 
   kc = round_up_po2(kc, 4 * sizeof(int8_t));
   const size_t kremainder = (kc & 63) ? (kc & 63) : 64;
@@ -66,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512amx(
   // Load tile configuration
   __attribute__((aligned(64))) struct __tile_config tile_data = {0};
   tile_data.palette_id = 1;
-  tile_data.rows[0] = mr;              // tmm0 = res 0
-  tile_data.rows[1] = mr;              // tmm1 = res 1
-  tile_data.rows[2] = mr;              // tmm2 = res 2
-  tile_data.rows[3] = mr;              // tmm3 = res 3
+  tile_data.rows[0] = mr;              // tmm0 = res[0]
+  tile_data.rows[1] = mr;              // tmm1 = res[1]
+  tile_data.rows[2] = mr;              // tmm2 = res[2]
+  tile_data.rows[3] = mr;              // tmm3 = res[3]
   tile_data.rows[4] = mr;              // tmm4 = input
   tile_data.rows[5] = 16;              // tmm5 = weights
   tile_data.rows[6] = mr;              // tmm6 = input remainder
   tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder
 
-  tile_data.colsb[0] = 64;          // tmm0 = res 0
-  tile_data.colsb[1] = 64;          // tmm1 = res 1
-  tile_data.colsb[2] = 64;          // tmm2 = res 1
-  tile_data.colsb[3] = 64;          // tmm3 = res 1
+  tile_data.colsb[0] = 64;          // tmm0 = res[0]
+  tile_data.colsb[1] = 64;          // tmm1 = res[1]
+  tile_data.colsb[2] = 64;          // tmm2 = res[2]
+  tile_data.colsb[3] = 64;          // tmm3 = res[3]
   tile_data.colsb[4] = 64;          // tmm4 = input
   tile_data.colsb[5] = 64;          // tmm5 = weights
   tile_data.colsb[6] = kremainder;  // tmm6 = input remainder
@@ -251,9 +256,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512amx(
       p -= 7 * sizeof(void*);
     } while (p != 0);
 
-    // Add tile to bias
-    _tile_stored(0, res0, 64);
+    // TODO: Instead of processing up to 4 tiles (16x64) consider
+    // quantizing 1 tile at a time (16 registers)
+    _tile_stored(0, &res[0][0], 64);
+
+    // TODO: Fix msan for AMX
+    #if defined(__has_feature)
+      #if __has_feature(memory_sanitizer)
+        __msan_unpoison(res, sizeof(res));
+      #endif
+    #endif
+    // TODO: Instead of processing up to 4 tiles (16x64) consider
+    // quantizing 1 row at a time.
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -261,13 +277,13 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512amx( __m512i vacc4x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc5x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c index cdccba899b2d..de1c815118b5 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni-prfm.c @@ -71,9 +71,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -147,26 +145,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni_prfm( 
size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -194,19 +192,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), 
vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni.c index 3a8af92bc27d..476bfd67b848 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c4-minmax-avx512vnni.c @@ -70,9 +70,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -146,26 +144,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const 
__m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -191,19 +189,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c index e22ef4138603..5ebc2a85b726 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni-prfm.c @@ -71,9 +71,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni_prfm( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -148,26 +146,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni_prfm( size_t k = kc; while 
(k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -213,19 +211,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni.c index cb3767d7ef20..4d379cb2043e 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x16c8-minmax-avx512vnni.c @@ -70,9 +70,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni( c6 = c5; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -147,26 +145,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -208,19 +206,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x32c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x32c4-minmax-avx512amx.c index 9231711515c8..8d42482a4ba4 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x32c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x32c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // 
LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,8 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x32c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; + __attribute__((aligned(64))) int32_t res[2][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -67,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -260,10 +264,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x32c4__avx512amx( p -= 7 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
+ // Add tile to bias __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -278,20 +293,20 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x32c4__avx512amx( __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + 
vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x64c4-minmax-avx512amx.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x64c4-minmax-avx512amx.c index 7d4a9661aea4..5f30f5626999 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x64c4-minmax-avx512amx.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x64c4-minmax-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -46,10 +51,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -69,19 +71,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -278,12 +280,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( p -= 7 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc0xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); @@ -312,34 +325,34 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx( __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_mullo_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6xWXYZabcdefghijkl = _mm512_mullo_epi32(vksumWXYZabcdefghijkl, _mm512_set1_epi32((int) quantization_params->zero_point)); __m512i vacc6xmnopqrstuvwxyz01 = _mm512_mullo_epi32(vksummnopqrstuvwxyz01, _mm512_set1_epi32((int) quantization_params->zero_point)); - vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - vacc6xGHIJKLMNOPQRSTUV = 
_mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + vacc0x0123456789ABCDEF = _mm512_add_epi32(vacc0x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc0xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vacc0xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc0xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + vacc1x0123456789ABCDEF = _mm512_add_epi32(vacc1x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc1xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vacc1xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc1xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + vacc2x0123456789ABCDEF = _mm512_add_epi32(vacc2x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc2xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vacc2xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc2xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + vacc3x0123456789ABCDEF = _mm512_add_epi32(vacc3x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc3xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vacc3xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc3xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + vacc4x0123456789ABCDEF = _mm512_add_epi32(vacc4x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc4xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vacc4xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc4xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + vacc5x0123456789ABCDEF = _mm512_add_epi32(vacc5x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc5xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vacc5xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc5xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + vacc6x0123456789ABCDEF = _mm512_add_epi32(vacc6x0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vacc6xGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vacc6xWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vacc6xmnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c index 
d50f92832aef..5ae84cd3b4d3 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni-prfm.c @@ -71,9 +71,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -148,26 +146,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - 
const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -211,19 +209,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni.c index 09ab1132ea7d..25edeb2b86af 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avx256vnni.c @@ -70,9 +70,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -147,26 +145,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -208,19 +206,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i 
va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c index cfd9273fc4b6..d63a8c52334e 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni-prfm.c @@ -71,9 +71,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -148,26 +146,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i 
va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -211,19 +209,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c index 67a2d3cdcc07..bf098f8b2b79 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c @@ -70,9 +70,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni( c6 = c5; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 
voutput_max = _mm256_set1_ps(params->scalar.max); @@ -147,26 +145,26 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -208,19 +206,19 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c index e50367109e27..d21c389ebc95 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni-prfm.c @@ -75,9 +75,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -159,29 +157,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -211,21 +209,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i 
vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni.c index 72cae0384a79..b7c6cd911d99 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c4-minmax-avx512vnni.c @@ -74,9 +74,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -158,29 +156,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -208,21 +206,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c index c5dd6c8aa95f..d889b9df5174 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni-prfm.c @@ -75,9 +75,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni_prfm( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -160,29 +158,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -232,21 +230,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), 
vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni.c index 4d260d724aa8..60b1daaabf5b 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x16c8-minmax-avx512vnni.c @@ -74,9 +74,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni( c7 = c6; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -159,29 +157,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -227,21 +225,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); 
+ const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c index f764d7fadb55..84be41470aae 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni-prfm.c @@ -75,9 +75,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -160,29 +158,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const 
__m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -230,21 +228,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff 
--git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni.c index 6152cc4ed3fa..686569654bd4 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avx256vnni.c @@ -74,9 +74,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -159,29 +157,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -227,21 +225,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c index 8f8020605f30..e189164e72d0 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c @@ -75,9 +75,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = 
_mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -160,29 +158,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); 
+ const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -230,21 +228,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c index 8e4ce706cafd..85ce5ab96b96 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c @@ -74,9 +74,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni( c7 = c6; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -159,29 +157,29 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), 
vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -227,21 +225,21 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni-prfm.c index aba2e0520f9b..f1645a21d211 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni-prfm.c @@ -79,9 +79,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -171,32 +169,32 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni_prfm( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = 
_mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -228,23 +226,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni.c index 0491d8eb191f..bf3d7ffa782f 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c4-minmax-avx512vnni.c @@ -78,9 +78,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -170,32 +168,32 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni( size_t k = kc; while (k >= 8 * sizeof(int8_t)) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); - const __m512i va0x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); + const __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0 + 4)); a0 += 8; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); - const __m512i va1x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); + const __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1 + 4)); a1 += 8; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); - const __m512i va2x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); + const __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2 + 4)); a2 += 8; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); - const __m512i va3x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); + const __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3 + 4)); a3 += 8; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); - const __m512i va4x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); + const __m512i va4x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a4 + 4)); a4 += 8; - const 
__m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); - const __m512i va5x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); + const __m512i va5x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a5 + 4)); a5 += 8; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); - const __m512i va6x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); + const __m512i va6x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a6 + 4)); a6 += 8; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); - const __m512i va7x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); + const __m512i va7x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a7 + 4)); a7 += 8; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); - const __m512i va8x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); + const __m512i va8x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a8 + 4)); a8 += 8; const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w); @@ -225,23 +223,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni( } if (k != 0) { - const __m512i va0x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a0)), vsign_mask); + const __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4; - const __m512i va1x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a1)), vsign_mask); + const __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4; - const __m512i va2x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a2)), vsign_mask); + const __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4; - const __m512i va3x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a3)), vsign_mask); + const __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4; - const __m512i va4x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a4)), vsign_mask); + const __m512i va4x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a4)); a4 += 4; - const __m512i va5x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a5)), vsign_mask); + const __m512i va5x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a5)); a5 += 4; - const __m512i va6x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a6)), vsign_mask); + const __m512i va6x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a6)); a6 += 4; - const __m512i va7x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a7)), vsign_mask); + const __m512i va7x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a7)); a7 += 4; - const __m512i va8x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a8)), vsign_mask); + const __m512i va8x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a8)); a8 += 4; const __m512i vb0123456789ABCDEF = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni-prfm.c 
b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni-prfm.c index 16367a3a2e67..401d1d9070bf 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni-prfm.c @@ -79,9 +79,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni_prfm( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -172,32 +170,32 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = 
_mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -251,23 +249,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni_prfm( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni.c index b067294ccf92..31d1f44d89c7 100644 --- 
a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x16c8-minmax-avx512vnni.c @@ -78,9 +78,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni( c8 = c7; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -171,32 +169,32 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m512i va0x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); + const __m512i va0x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m512i va1x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); + const __m512i va1x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m512i va2x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); + const __m512i va2x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m512i va3x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); + const __m512i va3x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m512i va4x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); + const __m512i va4x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m512i va5x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); + const __m512i va5x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m512i va6x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) 
unaligned_load_u64(a6 + 8)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); + const __m512i va6x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m512i va7x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); + const __m512i va7x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m512i va8x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); + const __m512i va8x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m512i vb01234567x01234567 = _mm512_load_si512(w); @@ -246,23 +244,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni( } if (k != 0) { - const __m512i va0x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m512i va0x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m512i va1x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m512i va1x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m512i va2x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m512i va2x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m512i va3x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m512i va3x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m512i va4x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m512i va4x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m512i va5x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m512i va5x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m512i va6x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m512i va6x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m512i va7x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m512i va7x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m512i va8x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m512i va8x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m512i vb01234567x01234567 = _mm512_load_si512(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c index e1398deda7c2..3203bf47690b 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni-prfm.c @@ -79,9 +79,7 @@ void 
xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -172,32 +170,32 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i 
va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -249,23 +247,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni.c b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni.c index 7d7c68fc6590..b5ee9e481ee9 100644 --- a/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni.c +++ b/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-9x8c8-minmax-avx256vnni.c @@ -78,9 +78,7 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni( c8 = c7; } - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); - 
const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); @@ -171,32 +169,32 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni( size_t k = kc; while (k >= 16 * sizeof(int8_t)) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); - const __m256i va0x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); + const __m256i va0x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0 + 8)); a0 += 16; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); - const __m256i va1x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); + const __m256i va1x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1 + 8)); a1 += 16; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); - const __m256i va2x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); + const __m256i va2x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2 + 8)); a2 += 16; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); - const __m256i va3x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); + const __m256i va3x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3 + 8)); a3 += 16; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); - const __m256i va4x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); + const __m256i va4x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4 + 8)); a4 += 16; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); - const __m256i va5x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); + const __m256i va5x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5 + 8)); a5 += 16; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); - const __m256i va6x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); + const __m256i va6x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6 + 8)); a6 += 16; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) 
unaligned_load_u64(a7)), vsign_mask); - const __m256i va7x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); + const __m256i va7x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7 + 8)); a7 += 16; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); - const __m256i va8x89ABCDEF = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); + const __m256i va8x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8 + 8)); a8 += 16; const __m256i vb01234567x0123 = _mm256_load_si256(w); @@ -246,23 +244,23 @@ void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni( } if (k != 0) { - const __m256i va0x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)), vsign_mask); + const __m256i va0x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a0)); a0 += 8; - const __m256i va1x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)), vsign_mask); + const __m256i va1x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a1)); a1 += 8; - const __m256i va2x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)), vsign_mask); + const __m256i va2x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a2)); a2 += 8; - const __m256i va3x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)), vsign_mask); + const __m256i va3x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a3)); a3 += 8; - const __m256i va4x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)), vsign_mask); + const __m256i va4x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a4)); a4 += 8; - const __m256i va5x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)), vsign_mask); + const __m256i va5x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a5)); a5 += 8; - const __m256i va6x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)), vsign_mask); + const __m256i va6x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a6)); a6 += 8; - const __m256i va7x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)), vsign_mask); + const __m256i va7x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a7)); a7 += 8; - const __m256i va8x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)), vsign_mask); + const __m256i va8x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a8)); a8 += 8; const __m256i vb01234567x0123 = _mm256_load_si256(w); diff --git a/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c4-mstep4-aarch64-neondot.c b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c4-mstep4-aarch64-neondot.c new file mode 100644 index 000000000000..39e8208a2508 --- /dev/null +++ b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c4-mstep4-aarch64-neondot.c @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI + // Keep this line indented to avoid it being pulled out of the #ifdef when the + // sources are amalgamated. 
+ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the `kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod` +// GEMM microkernel with a name that is compatible with our tooling. +void xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + union xnn_f32_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod( + m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + assert( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`." && + 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c8-mstep4-neoni8mm.c b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c8-mstep4-neoni8mm.c new file mode 100644 index 000000000000..fa2374d0b07e --- /dev/null +++ b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-16x4c8-mstep4-neoni8mm.c @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI + // Keep this line indented to avoid it being pulled out of the #ifdef when the + // sources are amalgamated. + #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the `kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm` +// GEMM microkernel with a name that is compatible with our tooling. +void xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + union xnn_f32_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm( + m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + assert( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`." && 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c4-aarch64-neondot.c b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c4-aarch64-neondot.c new file mode 100644 index 000000000000..7bd526e3e61e --- /dev/null +++ b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c4-aarch64-neondot.c @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI + // Keep this line indented to avoid it being pulled out of the #ifdef when the + // sources are amalgamated. 
+ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the +// `kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod` GEMM +// microkernel with a name that is compatible with our tooling. +void xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + union xnn_f32_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod( + m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + assert( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`." && + 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c8-aarch64-neondot.c b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c8-aarch64-neondot.c new file mode 100644 index 000000000000..859c444ab379 --- /dev/null +++ b/src/qp8-f32-qc8w-gemm/qp8-f32-qc8w-gemm-minmax-1x4c8-aarch64-neondot.c @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include + +#include "xnnpack/microparams.h" + +#if XNN_ENABLE_KLEIDIAI + // Keep this line indented to avoid it being pulled out of the #ifdef when the + // sources are amalgamated. + #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h" +#endif // XNN_ENABLE_KLEIDIAI + +// Wraps the +// `kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod` GEMM +// microkernel with a name that is compatible with our tooling. +void xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot( + size_t m, size_t n, size_t k, const void* lhs_packed, + const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, + union xnn_f32_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]) { +#if XNN_ENABLE_KLEIDIAI + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod( + m, n, k, lhs_packed, rhs_packed, dst, dst_stride_row, dst_stride_col, + minmax_params->scalar.min, minmax_params->scalar.max); +#else + assert( + "Calling KleidiAI microkernel wrapper, but XNNPACK was compiled without " + "`XNN_ENABLE_KLEIDIAI`." 
&& 0); +#endif // XNN_ENABLE_KLEIDIAI +} diff --git a/src/qs8-gemm/MRx16c4-avx512vnni.c.in b/src/qs8-gemm/MRx16c4-avx512vnni.c.in index 74ad7874bfbe..fbd0c1f68ed1 100644 --- a/src/qs8-gemm/MRx16c4-avx512vnni.c.in +++ b/src/qs8-gemm/MRx16c4-avx512vnni.c.in @@ -78,11 +78,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ c${M} = c${M-1}; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); + $if DATATYPE not in ["QD8", "QC4"]: + const __m512i vsign_mask = _mm512_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8", "QC4"]: $for M in range(MR): - const __m512i vinput_zero_point${M} = _mm512_set1_epi32((int) quantization_params[${M}].zero_point + 128); + const __m512i vinput_zero_point${M} = _mm512_set1_epi32((int) quantization_params[${M}].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -123,8 +124,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ size_t k = kc; while (k >= 8 * sizeof(int8_t)) { $for M in range(MR): - const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); - const __m512i va${M}x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M} + 4)), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a${M})); + const __m512i va${M}x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a${M} + 4)); + $else: + const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); + const __m512i va${M}x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M} + 4)), vsign_mask); a${M} += 8; $if DATATYPE == "QC4": @@ -164,7 +169,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4__ if (k != 0) { $for M in range(MR): - const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a${M})); + $else: + const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); a${M} += 4; $if DATATYPE == "QC4": diff --git a/src/qs8-gemm/MRx16c8-avx512vnni.c.in b/src/qs8-gemm/MRx16c8-avx512vnni.c.in index 02149624ed25..c2a1fb58e83c 100644 --- a/src/qs8-gemm/MRx16c8-avx512vnni.c.in +++ b/src/qs8-gemm/MRx16c8-avx512vnni.c.in @@ -103,14 +103,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ assert(kc % bl == 0); assert(bl % 32 == 0); - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8", "QC4", "QB4"]: $for M in range(MR): $if DATATYPE == "QB4": - const __m512 vinput_zero_point${M} = _mm512_set1_ps((float) quantization_params[${M}].zero_point + 128); + const __m512 vinput_zero_point${M} = _mm512_set1_ps((float) quantization_params[${M}].zero_point); $else: - const __m512i vinput_zero_point${M} = _mm512_set1_epi32((int) quantization_params[${M}].zero_point + 128); + const __m512i vinput_zero_point${M} = _mm512_set1_epi32((int) quantization_params[${M}].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ 
-125,6 +123,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ const __m512i vshl4 = _mm512_set1_epi64(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + const __m512i vsign_mask = _mm512_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE != "QC8": const __m512 vscale = _mm512_set1_ps(params->${PARAMS_STRUCT}.scale); // XNN_FORCE_REALIZATION(vscale); @@ -173,8 +173,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ size_t k = kc; ${_}while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - ${_}const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); - ${_}const __m512i va${M}x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); + $if DATATYPE in ["QD8", "QC4", "QB4"]: + ${_}const __m512i va${M}x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})); + ${_}const __m512i va${M}x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)); + $else: + ${_}const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); + ${_}const __m512i va${M}x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); ${_}a${M} += 16; $if DATATYPE in ["QC4", "QB4"]: @@ -231,7 +235,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ ${_}if (k != 0) { $for M in range(MR): - ${_}const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); + $if DATATYPE in ["QD8", "QC4", "QB4"]: + ${_}const __m512i va${M}x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})); + $else: + ${_}const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); ${_}a${M} += 8; $if DATATYPE in ["QC4", "QB4"]: diff --git a/src/qs8-gemm/MRx4c8-ssevnni.c.in b/src/qs8-gemm/MRx4c8-ssevnni.c.in index 589167df1cc7..26d2e281de08 100644 --- a/src/qs8-gemm/MRx4c8-ssevnni.c.in +++ b/src/qs8-gemm/MRx4c8-ssevnni.c.in @@ -84,11 +84,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ c${M} = c${M-1}; } - const __m128i vsign_mask = _mm_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: $for M in range(MR): - const __m128i vinput_zero_point${M} = _mm_set1_epi32((int) quantization_params[${M}].zero_point + 128); + const __m128i vinput_zero_point${M} = _mm_set1_epi32((int) quantization_params[${M}].zero_point); $if "F16" in DATATYPE: const __m128 voutput_min = _mm_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m128 voutput_max = _mm_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -105,6 +103,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ const __m128i vshl4 = _mm_set1_epi64x(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + const __m128i vsign_mask = _mm_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); const __m128 voutput_max_less_zero_point = _mm_set1_ps((int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_zero_point = _mm_set1_epi32(params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_set1_epi8(params->${PARAMS_STRUCT}.output_min); @@ -143,8 +143,12 @@ void
xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ size_t k = kc; while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); - const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + const __m128i va${M}x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})); + const __m128i va${M}x89ABCDEF = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)); + $else: + const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); + const __m128i va${M}x89ABCDEF = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); a${M} += 16; $if DATATYPE in ["QC4_F16", "QC4_F32"]: @@ -199,7 +203,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x4c8__$ if (k != 0) { $for M in range(MR): - const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); + $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + const __m128i va${M}x01234567 = _mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})); + $else: + const __m128i va${M}x01234567 = _mm_xor_si128(_mm_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); a${M} += 8; $if DATATYPE in ["QC4_F16", "QC4_F32"]: diff --git a/src/qs8-gemm/MRx8c4-avxvnni.c.in b/src/qs8-gemm/MRx8c4-avxvnni.c.in index 0d7788094517..da4ba4cad843 100644 --- a/src/qs8-gemm/MRx8c4-avxvnni.c.in +++ b/src/qs8-gemm/MRx8c4-avxvnni.c.in @@ -83,11 +83,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c4__$ c${M} = c${M-1}; } - const __m256i vsign_mask =_mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8", "QC4"]: $for M in range(MR): - const __m256i vinput_zero_point${M} = _mm256_set1_epi32((int) quantization_params[${M}].zero_point + 128); + const __m256i vinput_zero_point${M} = _mm256_set1_epi32((int) quantization_params[${M}].zero_point); const __m256 voutput_min = _mm256_set1_ps(params->scalar.min); const __m256 voutput_max = _mm256_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -99,6 +97,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c4__$ const __m256i vshl4 = _mm256_set1_epi64x(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + const __m256i vsign_mask =_mm256_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE != "QC8": const __m256 vscale = _mm256_set1_ps(params->${PARAMS_STRUCT}.scale); // XNN_FORCE_REALIZATION(vscale); @@ -128,13 +128,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c4__$ while (k >= ${UNROLL * 4} * sizeof(${XINT8_T})) { $for M in range(MR): $for K in range(UNROLL): - __m256i va${M}x${K}x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a${M} + ${4 * K})); + $if DATATYPE in ["QD8", "QC4"]: + __m256i va${M}x${K}x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a${M} + ${4 * K})); + $else: + __m256i va${M}x${K}x0123 = _mm256_xor_si256(_mm256_set1_epi32((int) unaligned_load_u32(a${M} + ${4 * K})), vsign_mask); a${M} += ${4 * UNROLL}; - $for M in range(MR): - $for K in range(UNROLL): - va${M}x${K}x0123 = _mm256_xor_si256(va${M}x${K}x0123, vsign_mask); - $if DATATYPE in ["QS8", "QD8"]: $for K in range(UNROLL): const __m256i vb${K}x01234567 = _mm256_load_si256((const __m256i*) ((const ${XINT8_T}*) w + ${32 * K})); @@
-170,12 +169,12 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c4__$ ${VACC(M,K1)}x01234567 = _mm256_add_epi32(${VACC(M,K1)}x01234567, ${VACC(M,K2)}x01234567); while (k != 0) { $for M in range(MR): - __m256i va${M}x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a${M})); + $if DATATYPE in ["QD8", "QC4"]: + __m256i va${M}x0123 = _mm256_set1_epi32((int) unaligned_load_u32(a${M})); + $else: + __m256i va${M}x0123 = _mm256_xor_si256(_mm256_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); a${M} += 4; - $for M in range(MR): - va${M}x0123 = _mm256_xor_si256(va${M}x0123, vsign_mask); - $if DATATYPE in ["QS8", "QD8"]: const __m256i vb01234567 = _mm256_load_si256(w); $elif DATATYPE in ["QC4"]: diff --git a/src/qs8-gemm/MRx8c8-avx512vnni.c.in b/src/qs8-gemm/MRx8c8-avx512vnni.c.in index 260bfb181f0a..945830661c1c 100644 --- a/src/qs8-gemm/MRx8c8-avx512vnni.c.in +++ b/src/qs8-gemm/MRx8c8-avx512vnni.c.in @@ -75,11 +75,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ c${M} = c${M-1}; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8", "QC4"]: $for M in range(MR): - const __m512i vinput_zero_point${M} = _mm512_set1_epi32((int) quantization_params[${M}].zero_point + 128); + const __m512i vinput_zero_point${M} = _mm512_set1_epi32((int) quantization_params[${M}].zero_point); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); // XNN_FORCE_REALIZATION(voutput_min); @@ -91,6 +89,8 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ const __m512i vshl4 = _mm512_set1_epi64(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + const __m512i vsign_mask = _mm512_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE != "QC8": const __m512 vscale = _mm512_set1_ps(params->${PARAMS_STRUCT}.scale); // XNN_FORCE_REALIZATION(vscale); @@ -122,8 +122,13 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ size_t k = kc; while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); - const __m512i va${M}x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})); + const __m512i va${M}x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)); + $else: + const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); + const __m512i va${M}x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); + a${M} += 16; $if DATATYPE == "QC4": @@ -173,7 +178,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8__ if (k != 0) { $for M in range(MR): - const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})); + $else: + const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); a${M} += 8; $if DATATYPE == "QC4": diff --git a/src/qs8-gemm/MRx8c8-avxvnni.c.in b/src/qs8-gemm/MRx8c8-avxvnni.c.in index 924f01b650c1..a3bc85d3767d 100644 ---
a/src/qs8-gemm/MRx8c8-avxvnni.c.in +++ b/src/qs8-gemm/MRx8c8-avxvnni.c.in @@ -82,15 +82,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__$ c${M} = c${M-1}; } - $if VARIANT != "AVXVNNIINT8": - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: $for M in range(MR): - $if VARIANT == "AVXVNNIINT8": - const __m256i vinput_zero_point${M} = _mm256_set1_epi32((int) quantization_params[${M}].zero_point); - $else: - const __m256i vinput_zero_point${M} = _mm256_set1_epi32((int) quantization_params[${M}].zero_point + 128); + const __m256i vinput_zero_point${M} = _mm256_set1_epi32((int) quantization_params[${M}].zero_point); $if "F16" in DATATYPE: const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); const __m256 voutput_max = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.max)); @@ -109,6 +103,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__$ const __m256i vshl4 = _mm256_set1_epi64x(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + $if VARIANT != "AVXVNNIINT8": + const __m256i vsign_mask = _mm256_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); const __m256 voutput_max_less_zero_point = _mm256_set1_ps((int32_t) params->${PARAMS_STRUCT}.output_max - (int32_t) params->${PARAMS_STRUCT}.output_zero_point); const __m256i voutput_zero_point = _mm256_set1_epi32(params->${PARAMS_STRUCT}.output_zero_point); const __m128i voutput_min = _mm_set1_epi8(params->${PARAMS_STRUCT}.output_min); @@ -137,7 +134,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__$ size_t k = kc; while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - $if VARIANT == "AVXVNNIINT8": + $if VARIANT == "AVXVNNIINT8" or DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: const __m256i va${M}x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M})); const __m256i va${M}x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)); $else: @@ -197,7 +194,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__$ if (k != 0) { $for M in range(MR): - $if VARIANT == "AVXVNNIINT8": + $if VARIANT == "AVXVNNIINT8" or DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: const __m256i va${M}x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M})); $else: const __m256i va${M}x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); diff --git a/src/qs8-gemm/c4-avx512amx.c.in b/src/qs8-gemm/c4-avx512amx.c.in index 06c64b0f102a..a5f101c352bf 100644 --- a/src/qs8-gemm/c4-avx512amx.c.in +++ b/src/qs8-gemm/c4-avx512amx.c.in @@ -10,6 +10,11 @@ $assert 1 <= MR <= 16 $assert REQUANTIZATION == "FP32" or not REQUANTIZATION $assert DATATYPE in ["QD8_F32", "QD8_F16", "QC8", "QC4_F32"] #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -57,10 +62,9 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - $for N in range(0, NR, 16): - __attribute__((aligned(64))) int32_t res${N // 16}[${MR} * 16]; $if DATATYPE in ["QC4_F32", "QC4_F16"]: __attribute__((aligned(64))) int8_t weight_buffer[16 * 64]; + __attribute__((aligned(64))) int32_t res[${NR // 16}][${MR} * 16]; kc = round_up_po2(kc, 4 *
sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -79,19 +83,19 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -202,21 +206,33 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR}c k -= kremainder * sizeof(int8_t); } - // Add tile to bias + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) $for N in range(0, NR, 16): - _tile_stored(${N // 16}, res${N // 16}, 64); + _tile_stored(${N // 16}, &res[${N // 16}][0], 64); + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
$if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: $for M in range(MR): $for N in range(0, NR, 16): __m512i vacc${M}x${ABC[N:N+16]} = _mm512_mullo_epi32(vksum${ABC[N:N+16]}, _mm512_set1_epi32((int) quantization_params[${M}].zero_point)); + // Add tile to bias $for M in range(MR): $for N in range(0, NR, 16): - vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vacc${M}x${ABC[N:N+16]}, _mm512_load_epi32(res${N // 16} + ${M * 16})); + vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vacc${M}x${ABC[N:N+16]}, _mm512_load_epi32(&res[${N // 16}][0] + ${M * 16})); $else: + // Add tile to bias $for M in range(MR): $for N in range(0, NR, 16): - __m512i vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vksum${ABC[N:N+16]}, _mm512_load_epi32(res${N // 16} + ${M * 16})); + __m512i vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vksum${ABC[N:N+16]}, _mm512_load_epi32(&res[${N // 16}][0] + ${M * 16})); $if DATATYPE in ["QC4_F32", "QC4_F16"]: $for M in range(MR): diff --git a/src/qs8-igemm/MRx16c4-avx512vnni.c.in b/src/qs8-igemm/MRx16c4-avx512vnni.c.in index 6cd913f15662..9d9b1182e5eb 100644 --- a/src/qs8-igemm/MRx16c4-avx512vnni.c.in +++ b/src/qs8-igemm/MRx16c4-avx512vnni.c.in @@ -74,10 +74,8 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ c${M} = c${M-1}; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8", "QC4"]: - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -92,6 +90,8 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ const __m512i vshl4 = _mm512_set1_epi64(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + const __m512i vsign_mask = _mm512_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE != "QC8": const __m512 vscale = _mm512_set1_ps(params->${PARAMS_STRUCT}.scale); // XNN_FORCE_REALIZATION(vscale); @@ -128,8 +128,12 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ size_t k = kc; while (k >= 8 * sizeof(int8_t)) { $for M in range(MR): - const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); - const __m512i va${M}x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M} + 4)), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a${M})); + const __m512i va${M}x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a${M} + 4)); + $else: + const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); + const __m512i va${M}x4567 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M} + 4)), vsign_mask); a${M} += 8; $if DATATYPE == "QC4": @@ -163,7 +167,10 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c4_ if (k != 0) { $for M in range(MR): - const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) unaligned_load_u32(a${M})), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a${M})); + $else: + const __m512i va${M}x0123 = _mm512_xor_epi32(_mm512_set1_epi32((int) 
unaligned_load_u32(a${M})), vsign_mask); a${M} += 4; $if DATATYPE == "QC4": diff --git a/src/qs8-igemm/MRx16c8-avx512vnni.c.in b/src/qs8-igemm/MRx16c8-avx512vnni.c.in index ff02074ee031..da2ccf965146 100644 --- a/src/qs8-igemm/MRx16c8-avx512vnni.c.in +++ b/src/qs8-igemm/MRx16c8-avx512vnni.c.in @@ -74,10 +74,8 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ c${M} = c${M-1}; } - const __m512i vsign_mask = _mm512_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8", "QC4"]: - const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128); + const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point); const __m512 vinput_inv_scale = _mm512_set1_ps(quantization_params->inv_scale); const __m512 voutput_min = _mm512_set1_ps(params->scalar.min); const __m512 voutput_max = _mm512_set1_ps(params->scalar.max); @@ -92,6 +90,8 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ const __m512i vshl4 = _mm512_set1_epi64(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + const __m512i vsign_mask = _mm512_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE != "QC8": const __m512 vscale = _mm512_set1_ps(params->${PARAMS_STRUCT}.scale); // XNN_FORCE_REALIZATION(vscale); @@ -134,8 +134,12 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ size_t k = kc; while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); - const __m512i va${M}x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})); + const __m512i va${M}x89ABCDEF = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)); + $else: + const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); + const __m512i va${M}x89ABCDEF = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M} + 8)), vsign_mask); a${M} += 16; $if DATATYPE == "QC4": @@ -185,7 +189,10 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x16c8_ if (k != 0) { $for M in range(MR): - const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); + $if DATATYPE in ["QD8", "QC4"]: + const __m512i va${M}x01234567 = _mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})); + $else: + const __m512i va${M}x01234567 = _mm512_xor_epi64(_mm512_set1_epi64((int64_t) unaligned_load_u64(a${M})), vsign_mask); a${M} += 8; $if DATATYPE == "QC4": diff --git a/src/qs8-igemm/MRx8c8-avxvnni.c.in b/src/qs8-igemm/MRx8c8-avxvnni.c.in index 6955141db3e8..d0e734e9cc0b 100644 --- a/src/qs8-igemm/MRx8c8-avxvnni.c.in +++ b/src/qs8-igemm/MRx8c8-avxvnni.c.in @@ -81,14 +81,11 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__ c${M} = c${M-1}; } - $if VARIANT != "AVXVNNIINT8": - const __m256i vsign_mask = _mm256_set1_epi8(0x80); - XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: $if VARIANT == "AVXVNNIINT8": const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); $else: - const __m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point + 128); + const 
__m256i vinput_zero_point = _mm256_set1_epi32((int) quantization_params->zero_point); const __m256 vinput_inv_scale = _mm256_set1_ps(quantization_params->inv_scale); $if "F16" in DATATYPE: const __m256 voutput_min = _mm256_cvtph_ps(_mm_set1_epi16(*(const uint16_t*) &params->scalar.min)); @@ -107,6 +104,9 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__ const __m256i vshl4 = _mm256_set1_epi64(0x01020408); XNN_FORCE_REALIZATION(vshl4); $else: + $if VARIANT != "AVXVNNIINT8": + const __m256i vsign_mask = _mm256_set1_epi8(0x80); + XNN_FORCE_REALIZATION(vsign_mask); $if DATATYPE != "QC8": const __m256 vscale = _mm256_load_ps(params->${PARAMS_STRUCT}.scale); // XNN_FORCE_REALIZATION(vscale); @@ -149,7 +149,7 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__ size_t k = kc; while (k >= 16 * sizeof(int8_t)) { $for M in range(MR): - $if VARIANT == "AVXVNNIINT8": + $if VARIANT == "AVXVNNIINT8" or DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: const __m256i va${M}x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M})); const __m256i va${M}x89ABCDEF = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M} + 8)); $else: @@ -202,7 +202,7 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x8c8__ if (k != 0) { $for M in range(MR): - $if VARIANT == "AVXVNNIINT8": + $if VARIANT == "AVXVNNIINT8" or DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: const __m256i va${M}x01234567 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M})); $else: const __m256i va${M}x01234567 = _mm256_xor_si256(_mm256_set1_epi64x((int64_t) unaligned_load_u64(a${M})), vsign_mask); diff --git a/src/qs8-igemm/c4-avx512amx.c.in b/src/qs8-igemm/c4-avx512amx.c.in index 89c22aeec41c..1a86b134a3d4 100644 --- a/src/qs8-igemm/c4-avx512amx.c.in +++ b/src/qs8-igemm/c4-avx512amx.c.in @@ -10,6 +10,11 @@ $assert 1 <= MR <= 16 $assert REQUANTIZATION == "FP32" or not REQUANTIZATION $assert DATATYPE in ["QD8_F32", "QD8_F16", "QC8"] #include <assert.h> +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include <sanitizer/msan_interface.h> + #endif +#endif #include <immintrin.h> @@ -63,8 +68,7 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR} // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[${MR} * 16]; - $for N in range(0, NR, 16): - __attribute__((aligned(64))) int32_t res${N // 16}[${MR} * 16]; + __attribute__((aligned(64))) int32_t res[${NR // 16}][${MR} * 16]; kc = round_up_po2(kc, 4 * sizeof(${XINT8_T})); const size_t kremainder = (kc & 63) ?
(kc & 63) : 64; @@ -84,19 +88,19 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR} // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -219,21 +223,33 @@ void xnn_${DATATYPE_SPEC}_igemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x${NR} p -= ${MR} * sizeof(void*); } while (p != 0); - // Add tile to bias + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) $for N in range(0, NR, 16): - _tile_stored(${N // 16}, res${N // 16}, 64); + _tile_stored(${N // 16}, &res[${N // 16}][0], 64); + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
$if DATATYPE in ["QD8_F16", "QD8_F32", "QC4_F16", "QC4_F32"]: + // Add tile to bias $for M in range(MR): $for N in range(0, NR, 16): __m512i vacc${M}x${ABC[N:N+16]} = _mm512_mullo_epi32(vksum${ABC[N:N+16]}, _mm512_set1_epi32((int) quantization_params->zero_point)); $for M in range(MR): $for N in range(0, NR, 16): - vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vacc${M}x${ABC[N:N+16]}, _mm512_load_epi32(res${N // 16} + ${M * 16})); + vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vacc${M}x${ABC[N:N+16]}, _mm512_load_epi32(&res[${N // 16}][0] + ${M * 16})); $else: + // Add tile to bias $for M in range(MR): $for N in range(0, NR, 16): - __m512i vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vksum${ABC[N:N+16]}, _mm512_load_epi32(res${N // 16} + ${M * 16})); + __m512i vacc${M}x${ABC[N:N+16]} = _mm512_add_epi32(vksum${ABC[N:N+16]}, _mm512_load_epi32(&res[${N // 16}][0] + ${M * 16})); $if DATATYPE == "QC4_F32": $for M in range(MR): diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c index f537e3050a6a..40001255d830 100644 --- a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,32 +51,19 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( const int8_t* w13 = w12 + kc; const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -254,22 +265,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w14 + 448); xnn_prefetch_to_l1((const int8_t*) w15 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = 
_mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -392,28 +403,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = 
_mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -462,25 +467,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -541,46 +531,263 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( if XNN_UNPREDICTABLE(n < 16) { w15 = w14; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - 
xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); - xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); xnn_prefetch_to_l1((const int8_t*) w8 + 64); - xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); xnn_prefetch_to_l1((const int8_t*) w9 + 64); - xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); xnn_prefetch_to_l1((const int8_t*) w10 + 64); - xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); xnn_prefetch_to_l1((const int8_t*) w11 + 64); - xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); xnn_prefetch_to_l1((const int8_t*) w12 + 64); - xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); xnn_prefetch_to_l1((const int8_t*) w13 + 64); - xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const 
int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); xnn_prefetch_to_l1((const int8_t*) w14 + 64); - xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const 
int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + 
w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -648,28 +855,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -713,9 +914,10 @@ void 
xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c index d364f82053c7..72dabcf55dc6 100644 --- a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,32 +50,19 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); @@ -125,22 +136,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = 
_mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -247,28 +258,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = 
_mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -317,25 +322,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -397,13 +387,134 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( w15 = w14; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const 
__m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + 
_mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -455,28 +566,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -520,9 +625,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c index 859f7ddc820c..259268757636 100644 --- a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,32 +51,19 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( const int8_t* w13 = w12 + kc; const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -254,22 +265,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w14 + 448); xnn_prefetch_to_l1((const int8_t*) w15 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, 
v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -392,28 +403,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -462,25 +467,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -541,46 +531,263 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( if XNN_UNPREDICTABLE(n < 16) { w15 = w14; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const 
int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); - xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); xnn_prefetch_to_l1((const int8_t*) w8 + 64); - xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); xnn_prefetch_to_l1((const int8_t*) w9 + 64); - xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); xnn_prefetch_to_l1((const int8_t*) w10 + 64); - xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); xnn_prefetch_to_l1((const int8_t*) w11 + 64); - xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); xnn_prefetch_to_l1((const int8_t*) w12 + 64); - xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); xnn_prefetch_to_l1((const int8_t*) w13 + 64); - xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + 
xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); xnn_prefetch_to_l1((const int8_t*) w14 + 64); - xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + 
xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i 
*)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -648,28 +855,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 
+= k; w1 += k; @@ -713,9 +914,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c index 7fcca8d44b39..d19ca224744e 100644 --- a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,32 +50,19 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); @@ -125,22 +136,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, 
v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -247,28 +258,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) 
safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -317,25 +322,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -397,13 +387,134 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( w15 = w14; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = 
_mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + 
_mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -455,28 +566,22 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = 
_mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -520,9 +625,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni-prfm.c new file mode 100644 index 000000000000..0f4532523c75 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni-prfm.c @@ -0,0 +1,1852 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/c4-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + +XNN_INLINE static uint32_t safe_load_u32(const void* src, size_t k) { + uint32_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < k; ++i) { + value |= (uint32_t) bytes[i] << (i * 8); + } + return value; +} + + +void xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 64); + assert(kr == 4); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ?
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); + + do { + // NC main loop multiple of 64 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 64; n -= 64) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + const int8_t* w16 = w15 + kc; + const int8_t* w17 = w16 + kc; + const int8_t* w18 = w17 + kc; + const int8_t* w19 = w18 + kc; + const int8_t* w20 = w19 + kc; + const int8_t* w21 = w20 + kc; + const int8_t* w22 = w21 + kc; + const int8_t* w23 = w22 + kc; + const int8_t* w24 = w23 + kc; + const int8_t* w25 = w24 + kc; + const int8_t* w26 = w25 + kc; + const int8_t* w27 = w26 + kc; + const int8_t* w28 = w27 + kc; + const int8_t* w29 = w28 + kc; + const int8_t* w30 = w29 + kc; + const int8_t* w31 = w30 + kc; + const int8_t* w32 = w31 + kc; + const int8_t* w33 = w32 + kc; + const int8_t* w34 = w33 + kc; + const int8_t* w35 = w34 + kc; + const int8_t* w36 = w35 + kc; + const int8_t* w37 = w36 + kc; + const int8_t* w38 = w37 + kc; + const int8_t* w39 = w38 + kc; + const int8_t* w40 = w39 + kc; + const int8_t* w41 = w40 + kc; + const int8_t* w42 = w41 + kc; + const int8_t* w43 = w42 + kc; + const int8_t* w44 = w43 + kc; + const int8_t* w45 = w44 + kc; + const int8_t* w46 = w45 + kc; + const int8_t* w47 = w46 + kc; + const int8_t* w48 = w47 + kc; + const int8_t* w49 = w48 + kc; + const int8_t* w50 = w49 + kc; + const int8_t* w51 = w50 + kc; + const int8_t* w52 = w51 + kc; + const int8_t* w53 = w52 + kc; + const int8_t* w54 = w53 + kc; + const int8_t* w55 = w54 + kc; + const int8_t* w56 = w55 + kc; + const int8_t* w57 = w56 + kc; + const int8_t* w58 = w57 + kc; + const int8_t* w59 = w58 + kc; + const int8_t* w60 = w59 + kc; + const int8_t* w61 = w60 + kc; + const int8_t* w62 = w61 + kc; + const int8_t* w63 = w62 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + const __m256i vb16 = _mm256_loadu_si256((const __m256i*) (b + 16)); + const __m256i vb24 = _mm256_loadu_si256((const __m256i*) (b + 24)); + const __m256i vb32 = _mm256_loadu_si256((const __m256i*) (b + 32)); + const __m256i vb40 = _mm256_loadu_si256((const __m256i*) (b + 40)); + const __m256i vb48 = _mm256_loadu_si256((const __m256i*) (b + 48)); + const __m256i vb56 = _mm256_loadu_si256((const __m256i*) (b + 56)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + _mm256_storeu_si256((__m256i*) (out + 64), vb16); + _mm256_storeu_si256((__m256i*) (out + 96), vb24); + _mm256_storeu_si256((__m256i*) (out + 128), vb32); + _mm256_storeu_si256((__m256i*) (out + 160), vb40); + _mm256_storeu_si256((__m256i*) (out + 192), vb48); + _mm256_storeu_si256((__m256i*) (out + 224), vb56); + b += 64; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 64), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 96), _mm256_setzero_si256()); + 
_mm256_storeu_si256((__m256i*) (out + 128), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 160), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 192), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 224), _mm256_setzero_si256()); + } + out += 64 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + 
xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); + xnn_prefetch_to_l1((const int8_t*) w16 + 0); + xnn_prefetch_to_l1((const int8_t*) w16 + 64); + xnn_prefetch_to_l1((const int8_t*) w16 + 128); + xnn_prefetch_to_l1((const int8_t*) w16 + 192); + xnn_prefetch_to_l1((const int8_t*) w16 + 256); + xnn_prefetch_to_l1((const int8_t*) w16 + 320); + xnn_prefetch_to_l1((const int8_t*) w16 + 384); + xnn_prefetch_to_l1((const int8_t*) w17 + 0); + xnn_prefetch_to_l1((const int8_t*) w17 + 64); + xnn_prefetch_to_l1((const int8_t*) w17 + 128); + xnn_prefetch_to_l1((const int8_t*) w17 + 192); + xnn_prefetch_to_l1((const int8_t*) w17 + 256); + xnn_prefetch_to_l1((const int8_t*) w17 + 320); + xnn_prefetch_to_l1((const int8_t*) w17 + 384); + xnn_prefetch_to_l1((const int8_t*) w18 + 0); + xnn_prefetch_to_l1((const int8_t*) w18 + 64); + xnn_prefetch_to_l1((const int8_t*) w18 + 128); + xnn_prefetch_to_l1((const int8_t*) w18 + 192); + xnn_prefetch_to_l1((const int8_t*) w18 + 256); + xnn_prefetch_to_l1((const int8_t*) w18 + 320); + xnn_prefetch_to_l1((const int8_t*) w18 + 384); + xnn_prefetch_to_l1((const int8_t*) w19 + 0); + xnn_prefetch_to_l1((const int8_t*) w19 + 64); + xnn_prefetch_to_l1((const int8_t*) w19 + 128); + xnn_prefetch_to_l1((const int8_t*) w19 + 192); + xnn_prefetch_to_l1((const int8_t*) w19 + 256); + xnn_prefetch_to_l1((const int8_t*) w19 + 320); + xnn_prefetch_to_l1((const int8_t*) w19 + 384); + xnn_prefetch_to_l1((const int8_t*) w20 + 0); + 
xnn_prefetch_to_l1((const int8_t*) w20 + 64); + xnn_prefetch_to_l1((const int8_t*) w20 + 128); + xnn_prefetch_to_l1((const int8_t*) w20 + 192); + xnn_prefetch_to_l1((const int8_t*) w20 + 256); + xnn_prefetch_to_l1((const int8_t*) w20 + 320); + xnn_prefetch_to_l1((const int8_t*) w20 + 384); + xnn_prefetch_to_l1((const int8_t*) w21 + 0); + xnn_prefetch_to_l1((const int8_t*) w21 + 64); + xnn_prefetch_to_l1((const int8_t*) w21 + 128); + xnn_prefetch_to_l1((const int8_t*) w21 + 192); + xnn_prefetch_to_l1((const int8_t*) w21 + 256); + xnn_prefetch_to_l1((const int8_t*) w21 + 320); + xnn_prefetch_to_l1((const int8_t*) w21 + 384); + xnn_prefetch_to_l1((const int8_t*) w22 + 0); + xnn_prefetch_to_l1((const int8_t*) w22 + 64); + xnn_prefetch_to_l1((const int8_t*) w22 + 128); + xnn_prefetch_to_l1((const int8_t*) w22 + 192); + xnn_prefetch_to_l1((const int8_t*) w22 + 256); + xnn_prefetch_to_l1((const int8_t*) w22 + 320); + xnn_prefetch_to_l1((const int8_t*) w22 + 384); + xnn_prefetch_to_l1((const int8_t*) w23 + 0); + xnn_prefetch_to_l1((const int8_t*) w23 + 64); + xnn_prefetch_to_l1((const int8_t*) w23 + 128); + xnn_prefetch_to_l1((const int8_t*) w23 + 192); + xnn_prefetch_to_l1((const int8_t*) w23 + 256); + xnn_prefetch_to_l1((const int8_t*) w23 + 320); + xnn_prefetch_to_l1((const int8_t*) w23 + 384); + xnn_prefetch_to_l1((const int8_t*) w24 + 0); + xnn_prefetch_to_l1((const int8_t*) w24 + 64); + xnn_prefetch_to_l1((const int8_t*) w24 + 128); + xnn_prefetch_to_l1((const int8_t*) w24 + 192); + xnn_prefetch_to_l1((const int8_t*) w24 + 256); + xnn_prefetch_to_l1((const int8_t*) w24 + 320); + xnn_prefetch_to_l1((const int8_t*) w24 + 384); + xnn_prefetch_to_l1((const int8_t*) w25 + 0); + xnn_prefetch_to_l1((const int8_t*) w25 + 64); + xnn_prefetch_to_l1((const int8_t*) w25 + 128); + xnn_prefetch_to_l1((const int8_t*) w25 + 192); + xnn_prefetch_to_l1((const int8_t*) w25 + 256); + xnn_prefetch_to_l1((const int8_t*) w25 + 320); + xnn_prefetch_to_l1((const int8_t*) w25 + 384); + xnn_prefetch_to_l1((const int8_t*) w26 + 0); + xnn_prefetch_to_l1((const int8_t*) w26 + 64); + xnn_prefetch_to_l1((const int8_t*) w26 + 128); + xnn_prefetch_to_l1((const int8_t*) w26 + 192); + xnn_prefetch_to_l1((const int8_t*) w26 + 256); + xnn_prefetch_to_l1((const int8_t*) w26 + 320); + xnn_prefetch_to_l1((const int8_t*) w26 + 384); + xnn_prefetch_to_l1((const int8_t*) w27 + 0); + xnn_prefetch_to_l1((const int8_t*) w27 + 64); + xnn_prefetch_to_l1((const int8_t*) w27 + 128); + xnn_prefetch_to_l1((const int8_t*) w27 + 192); + xnn_prefetch_to_l1((const int8_t*) w27 + 256); + xnn_prefetch_to_l1((const int8_t*) w27 + 320); + xnn_prefetch_to_l1((const int8_t*) w27 + 384); + xnn_prefetch_to_l1((const int8_t*) w28 + 0); + xnn_prefetch_to_l1((const int8_t*) w28 + 64); + xnn_prefetch_to_l1((const int8_t*) w28 + 128); + xnn_prefetch_to_l1((const int8_t*) w28 + 192); + xnn_prefetch_to_l1((const int8_t*) w28 + 256); + xnn_prefetch_to_l1((const int8_t*) w28 + 320); + xnn_prefetch_to_l1((const int8_t*) w28 + 384); + xnn_prefetch_to_l1((const int8_t*) w29 + 0); + xnn_prefetch_to_l1((const int8_t*) w29 + 64); + xnn_prefetch_to_l1((const int8_t*) w29 + 128); + xnn_prefetch_to_l1((const int8_t*) w29 + 192); + xnn_prefetch_to_l1((const int8_t*) w29 + 256); + xnn_prefetch_to_l1((const int8_t*) w29 + 320); + xnn_prefetch_to_l1((const int8_t*) w29 + 384); + xnn_prefetch_to_l1((const int8_t*) w30 + 0); + xnn_prefetch_to_l1((const int8_t*) w30 + 64); + xnn_prefetch_to_l1((const int8_t*) w30 + 128); + xnn_prefetch_to_l1((const int8_t*) w30 + 192); + 
xnn_prefetch_to_l1((const int8_t*) w30 + 256); + xnn_prefetch_to_l1((const int8_t*) w30 + 320); + xnn_prefetch_to_l1((const int8_t*) w30 + 384); + xnn_prefetch_to_l1((const int8_t*) w31 + 0); + xnn_prefetch_to_l1((const int8_t*) w31 + 64); + xnn_prefetch_to_l1((const int8_t*) w31 + 128); + xnn_prefetch_to_l1((const int8_t*) w31 + 192); + xnn_prefetch_to_l1((const int8_t*) w31 + 256); + xnn_prefetch_to_l1((const int8_t*) w31 + 320); + xnn_prefetch_to_l1((const int8_t*) w31 + 384); + xnn_prefetch_to_l1((const int8_t*) w32 + 0); + xnn_prefetch_to_l1((const int8_t*) w32 + 64); + xnn_prefetch_to_l1((const int8_t*) w32 + 128); + xnn_prefetch_to_l1((const int8_t*) w32 + 192); + xnn_prefetch_to_l1((const int8_t*) w32 + 256); + xnn_prefetch_to_l1((const int8_t*) w32 + 320); + xnn_prefetch_to_l1((const int8_t*) w32 + 384); + xnn_prefetch_to_l1((const int8_t*) w33 + 0); + xnn_prefetch_to_l1((const int8_t*) w33 + 64); + xnn_prefetch_to_l1((const int8_t*) w33 + 128); + xnn_prefetch_to_l1((const int8_t*) w33 + 192); + xnn_prefetch_to_l1((const int8_t*) w33 + 256); + xnn_prefetch_to_l1((const int8_t*) w33 + 320); + xnn_prefetch_to_l1((const int8_t*) w33 + 384); + xnn_prefetch_to_l1((const int8_t*) w34 + 0); + xnn_prefetch_to_l1((const int8_t*) w34 + 64); + xnn_prefetch_to_l1((const int8_t*) w34 + 128); + xnn_prefetch_to_l1((const int8_t*) w34 + 192); + xnn_prefetch_to_l1((const int8_t*) w34 + 256); + xnn_prefetch_to_l1((const int8_t*) w34 + 320); + xnn_prefetch_to_l1((const int8_t*) w34 + 384); + xnn_prefetch_to_l1((const int8_t*) w35 + 0); + xnn_prefetch_to_l1((const int8_t*) w35 + 64); + xnn_prefetch_to_l1((const int8_t*) w35 + 128); + xnn_prefetch_to_l1((const int8_t*) w35 + 192); + xnn_prefetch_to_l1((const int8_t*) w35 + 256); + xnn_prefetch_to_l1((const int8_t*) w35 + 320); + xnn_prefetch_to_l1((const int8_t*) w35 + 384); + xnn_prefetch_to_l1((const int8_t*) w36 + 0); + xnn_prefetch_to_l1((const int8_t*) w36 + 64); + xnn_prefetch_to_l1((const int8_t*) w36 + 128); + xnn_prefetch_to_l1((const int8_t*) w36 + 192); + xnn_prefetch_to_l1((const int8_t*) w36 + 256); + xnn_prefetch_to_l1((const int8_t*) w36 + 320); + xnn_prefetch_to_l1((const int8_t*) w36 + 384); + xnn_prefetch_to_l1((const int8_t*) w37 + 0); + xnn_prefetch_to_l1((const int8_t*) w37 + 64); + xnn_prefetch_to_l1((const int8_t*) w37 + 128); + xnn_prefetch_to_l1((const int8_t*) w37 + 192); + xnn_prefetch_to_l1((const int8_t*) w37 + 256); + xnn_prefetch_to_l1((const int8_t*) w37 + 320); + xnn_prefetch_to_l1((const int8_t*) w37 + 384); + xnn_prefetch_to_l1((const int8_t*) w38 + 0); + xnn_prefetch_to_l1((const int8_t*) w38 + 64); + xnn_prefetch_to_l1((const int8_t*) w38 + 128); + xnn_prefetch_to_l1((const int8_t*) w38 + 192); + xnn_prefetch_to_l1((const int8_t*) w38 + 256); + xnn_prefetch_to_l1((const int8_t*) w38 + 320); + xnn_prefetch_to_l1((const int8_t*) w38 + 384); + xnn_prefetch_to_l1((const int8_t*) w39 + 0); + xnn_prefetch_to_l1((const int8_t*) w39 + 64); + xnn_prefetch_to_l1((const int8_t*) w39 + 128); + xnn_prefetch_to_l1((const int8_t*) w39 + 192); + xnn_prefetch_to_l1((const int8_t*) w39 + 256); + xnn_prefetch_to_l1((const int8_t*) w39 + 320); + xnn_prefetch_to_l1((const int8_t*) w39 + 384); + xnn_prefetch_to_l1((const int8_t*) w40 + 0); + xnn_prefetch_to_l1((const int8_t*) w40 + 64); + xnn_prefetch_to_l1((const int8_t*) w40 + 128); + xnn_prefetch_to_l1((const int8_t*) w40 + 192); + xnn_prefetch_to_l1((const int8_t*) w40 + 256); + xnn_prefetch_to_l1((const int8_t*) w40 + 320); + xnn_prefetch_to_l1((const int8_t*) w40 + 384); + 
xnn_prefetch_to_l1((const int8_t*) w41 + 0); + xnn_prefetch_to_l1((const int8_t*) w41 + 64); + xnn_prefetch_to_l1((const int8_t*) w41 + 128); + xnn_prefetch_to_l1((const int8_t*) w41 + 192); + xnn_prefetch_to_l1((const int8_t*) w41 + 256); + xnn_prefetch_to_l1((const int8_t*) w41 + 320); + xnn_prefetch_to_l1((const int8_t*) w41 + 384); + xnn_prefetch_to_l1((const int8_t*) w42 + 0); + xnn_prefetch_to_l1((const int8_t*) w42 + 64); + xnn_prefetch_to_l1((const int8_t*) w42 + 128); + xnn_prefetch_to_l1((const int8_t*) w42 + 192); + xnn_prefetch_to_l1((const int8_t*) w42 + 256); + xnn_prefetch_to_l1((const int8_t*) w42 + 320); + xnn_prefetch_to_l1((const int8_t*) w42 + 384); + xnn_prefetch_to_l1((const int8_t*) w43 + 0); + xnn_prefetch_to_l1((const int8_t*) w43 + 64); + xnn_prefetch_to_l1((const int8_t*) w43 + 128); + xnn_prefetch_to_l1((const int8_t*) w43 + 192); + xnn_prefetch_to_l1((const int8_t*) w43 + 256); + xnn_prefetch_to_l1((const int8_t*) w43 + 320); + xnn_prefetch_to_l1((const int8_t*) w43 + 384); + xnn_prefetch_to_l1((const int8_t*) w44 + 0); + xnn_prefetch_to_l1((const int8_t*) w44 + 64); + xnn_prefetch_to_l1((const int8_t*) w44 + 128); + xnn_prefetch_to_l1((const int8_t*) w44 + 192); + xnn_prefetch_to_l1((const int8_t*) w44 + 256); + xnn_prefetch_to_l1((const int8_t*) w44 + 320); + xnn_prefetch_to_l1((const int8_t*) w44 + 384); + xnn_prefetch_to_l1((const int8_t*) w45 + 0); + xnn_prefetch_to_l1((const int8_t*) w45 + 64); + xnn_prefetch_to_l1((const int8_t*) w45 + 128); + xnn_prefetch_to_l1((const int8_t*) w45 + 192); + xnn_prefetch_to_l1((const int8_t*) w45 + 256); + xnn_prefetch_to_l1((const int8_t*) w45 + 320); + xnn_prefetch_to_l1((const int8_t*) w45 + 384); + xnn_prefetch_to_l1((const int8_t*) w46 + 0); + xnn_prefetch_to_l1((const int8_t*) w46 + 64); + xnn_prefetch_to_l1((const int8_t*) w46 + 128); + xnn_prefetch_to_l1((const int8_t*) w46 + 192); + xnn_prefetch_to_l1((const int8_t*) w46 + 256); + xnn_prefetch_to_l1((const int8_t*) w46 + 320); + xnn_prefetch_to_l1((const int8_t*) w46 + 384); + xnn_prefetch_to_l1((const int8_t*) w47 + 0); + xnn_prefetch_to_l1((const int8_t*) w47 + 64); + xnn_prefetch_to_l1((const int8_t*) w47 + 128); + xnn_prefetch_to_l1((const int8_t*) w47 + 192); + xnn_prefetch_to_l1((const int8_t*) w47 + 256); + xnn_prefetch_to_l1((const int8_t*) w47 + 320); + xnn_prefetch_to_l1((const int8_t*) w47 + 384); + xnn_prefetch_to_l1((const int8_t*) w48 + 0); + xnn_prefetch_to_l1((const int8_t*) w48 + 64); + xnn_prefetch_to_l1((const int8_t*) w48 + 128); + xnn_prefetch_to_l1((const int8_t*) w48 + 192); + xnn_prefetch_to_l1((const int8_t*) w48 + 256); + xnn_prefetch_to_l1((const int8_t*) w48 + 320); + xnn_prefetch_to_l1((const int8_t*) w48 + 384); + xnn_prefetch_to_l1((const int8_t*) w49 + 0); + xnn_prefetch_to_l1((const int8_t*) w49 + 64); + xnn_prefetch_to_l1((const int8_t*) w49 + 128); + xnn_prefetch_to_l1((const int8_t*) w49 + 192); + xnn_prefetch_to_l1((const int8_t*) w49 + 256); + xnn_prefetch_to_l1((const int8_t*) w49 + 320); + xnn_prefetch_to_l1((const int8_t*) w49 + 384); + xnn_prefetch_to_l1((const int8_t*) w50 + 0); + xnn_prefetch_to_l1((const int8_t*) w50 + 64); + xnn_prefetch_to_l1((const int8_t*) w50 + 128); + xnn_prefetch_to_l1((const int8_t*) w50 + 192); + xnn_prefetch_to_l1((const int8_t*) w50 + 256); + xnn_prefetch_to_l1((const int8_t*) w50 + 320); + xnn_prefetch_to_l1((const int8_t*) w50 + 384); + xnn_prefetch_to_l1((const int8_t*) w51 + 0); + xnn_prefetch_to_l1((const int8_t*) w51 + 64); + xnn_prefetch_to_l1((const int8_t*) w51 + 128); + 
xnn_prefetch_to_l1((const int8_t*) w51 + 192); + xnn_prefetch_to_l1((const int8_t*) w51 + 256); + xnn_prefetch_to_l1((const int8_t*) w51 + 320); + xnn_prefetch_to_l1((const int8_t*) w51 + 384); + xnn_prefetch_to_l1((const int8_t*) w52 + 0); + xnn_prefetch_to_l1((const int8_t*) w52 + 64); + xnn_prefetch_to_l1((const int8_t*) w52 + 128); + xnn_prefetch_to_l1((const int8_t*) w52 + 192); + xnn_prefetch_to_l1((const int8_t*) w52 + 256); + xnn_prefetch_to_l1((const int8_t*) w52 + 320); + xnn_prefetch_to_l1((const int8_t*) w52 + 384); + xnn_prefetch_to_l1((const int8_t*) w53 + 0); + xnn_prefetch_to_l1((const int8_t*) w53 + 64); + xnn_prefetch_to_l1((const int8_t*) w53 + 128); + xnn_prefetch_to_l1((const int8_t*) w53 + 192); + xnn_prefetch_to_l1((const int8_t*) w53 + 256); + xnn_prefetch_to_l1((const int8_t*) w53 + 320); + xnn_prefetch_to_l1((const int8_t*) w53 + 384); + xnn_prefetch_to_l1((const int8_t*) w54 + 0); + xnn_prefetch_to_l1((const int8_t*) w54 + 64); + xnn_prefetch_to_l1((const int8_t*) w54 + 128); + xnn_prefetch_to_l1((const int8_t*) w54 + 192); + xnn_prefetch_to_l1((const int8_t*) w54 + 256); + xnn_prefetch_to_l1((const int8_t*) w54 + 320); + xnn_prefetch_to_l1((const int8_t*) w54 + 384); + xnn_prefetch_to_l1((const int8_t*) w55 + 0); + xnn_prefetch_to_l1((const int8_t*) w55 + 64); + xnn_prefetch_to_l1((const int8_t*) w55 + 128); + xnn_prefetch_to_l1((const int8_t*) w55 + 192); + xnn_prefetch_to_l1((const int8_t*) w55 + 256); + xnn_prefetch_to_l1((const int8_t*) w55 + 320); + xnn_prefetch_to_l1((const int8_t*) w55 + 384); + xnn_prefetch_to_l1((const int8_t*) w56 + 0); + xnn_prefetch_to_l1((const int8_t*) w56 + 64); + xnn_prefetch_to_l1((const int8_t*) w56 + 128); + xnn_prefetch_to_l1((const int8_t*) w56 + 192); + xnn_prefetch_to_l1((const int8_t*) w56 + 256); + xnn_prefetch_to_l1((const int8_t*) w56 + 320); + xnn_prefetch_to_l1((const int8_t*) w56 + 384); + xnn_prefetch_to_l1((const int8_t*) w57 + 0); + xnn_prefetch_to_l1((const int8_t*) w57 + 64); + xnn_prefetch_to_l1((const int8_t*) w57 + 128); + xnn_prefetch_to_l1((const int8_t*) w57 + 192); + xnn_prefetch_to_l1((const int8_t*) w57 + 256); + xnn_prefetch_to_l1((const int8_t*) w57 + 320); + xnn_prefetch_to_l1((const int8_t*) w57 + 384); + xnn_prefetch_to_l1((const int8_t*) w58 + 0); + xnn_prefetch_to_l1((const int8_t*) w58 + 64); + xnn_prefetch_to_l1((const int8_t*) w58 + 128); + xnn_prefetch_to_l1((const int8_t*) w58 + 192); + xnn_prefetch_to_l1((const int8_t*) w58 + 256); + xnn_prefetch_to_l1((const int8_t*) w58 + 320); + xnn_prefetch_to_l1((const int8_t*) w58 + 384); + xnn_prefetch_to_l1((const int8_t*) w59 + 0); + xnn_prefetch_to_l1((const int8_t*) w59 + 64); + xnn_prefetch_to_l1((const int8_t*) w59 + 128); + xnn_prefetch_to_l1((const int8_t*) w59 + 192); + xnn_prefetch_to_l1((const int8_t*) w59 + 256); + xnn_prefetch_to_l1((const int8_t*) w59 + 320); + xnn_prefetch_to_l1((const int8_t*) w59 + 384); + xnn_prefetch_to_l1((const int8_t*) w60 + 0); + xnn_prefetch_to_l1((const int8_t*) w60 + 64); + xnn_prefetch_to_l1((const int8_t*) w60 + 128); + xnn_prefetch_to_l1((const int8_t*) w60 + 192); + xnn_prefetch_to_l1((const int8_t*) w60 + 256); + xnn_prefetch_to_l1((const int8_t*) w60 + 320); + xnn_prefetch_to_l1((const int8_t*) w60 + 384); + xnn_prefetch_to_l1((const int8_t*) w61 + 0); + xnn_prefetch_to_l1((const int8_t*) w61 + 64); + xnn_prefetch_to_l1((const int8_t*) w61 + 128); + xnn_prefetch_to_l1((const int8_t*) w61 + 192); + xnn_prefetch_to_l1((const int8_t*) w61 + 256); + xnn_prefetch_to_l1((const int8_t*) w61 + 320); + 
xnn_prefetch_to_l1((const int8_t*) w61 + 384); + xnn_prefetch_to_l1((const int8_t*) w62 + 0); + xnn_prefetch_to_l1((const int8_t*) w62 + 64); + xnn_prefetch_to_l1((const int8_t*) w62 + 128); + xnn_prefetch_to_l1((const int8_t*) w62 + 192); + xnn_prefetch_to_l1((const int8_t*) w62 + 256); + xnn_prefetch_to_l1((const int8_t*) w62 + 320); + xnn_prefetch_to_l1((const int8_t*) w62 + 384); + xnn_prefetch_to_l1((const int8_t*) w63 + 0); + xnn_prefetch_to_l1((const int8_t*) w63 + 64); + xnn_prefetch_to_l1((const int8_t*) w63 + 128); + xnn_prefetch_to_l1((const int8_t*) w63 + 192); + xnn_prefetch_to_l1((const int8_t*) w63 + 256); + xnn_prefetch_to_l1((const int8_t*) w63 + 320); + xnn_prefetch_to_l1((const int8_t*) w63 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc16 = _mm256_setzero_si256(); + __m256i vacc24 = _mm256_setzero_si256(); + __m256i vacc32 = _mm256_setzero_si256(); + __m256i vacc40 = _mm256_setzero_si256(); + __m256i vacc48 = _mm256_setzero_si256(); + __m256i vacc56 = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of 64x4 + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w9)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w10)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w11)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w12)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w13)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w14)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w15)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w16)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w17)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w18)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w19)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w20)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w21)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w22)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w23)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w24)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w25)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w26)), 0x04); + v24 = 
_mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w27)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w28)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w29)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w30)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w31)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w32)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w33)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w34)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w35)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w36)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w37)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w38)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w39)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w40)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w41)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w42)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w43)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w44)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w45)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w46)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w47)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w48)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w49)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w50)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w51)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w52)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w53)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w54)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w55)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w56)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w57)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w58)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w59)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w60)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w61)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w62)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w63)), 0x80); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + 
xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + xnn_prefetch_to_l1((const int8_t*) w16 + 448); + xnn_prefetch_to_l1((const int8_t*) w17 + 448); + xnn_prefetch_to_l1((const int8_t*) w18 + 448); + xnn_prefetch_to_l1((const int8_t*) w19 + 448); + xnn_prefetch_to_l1((const int8_t*) w20 + 448); + xnn_prefetch_to_l1((const int8_t*) w21 + 448); + xnn_prefetch_to_l1((const int8_t*) w22 + 448); + xnn_prefetch_to_l1((const int8_t*) w23 + 448); + xnn_prefetch_to_l1((const int8_t*) w24 + 448); + xnn_prefetch_to_l1((const int8_t*) w25 + 448); + xnn_prefetch_to_l1((const int8_t*) w26 + 448); + xnn_prefetch_to_l1((const int8_t*) w27 + 448); + xnn_prefetch_to_l1((const int8_t*) w28 + 448); + xnn_prefetch_to_l1((const int8_t*) w29 + 448); + xnn_prefetch_to_l1((const int8_t*) w30 + 448); + xnn_prefetch_to_l1((const int8_t*) w31 + 448); + xnn_prefetch_to_l1((const int8_t*) w32 + 448); + xnn_prefetch_to_l1((const int8_t*) w33 + 448); + xnn_prefetch_to_l1((const int8_t*) w34 + 448); + xnn_prefetch_to_l1((const int8_t*) w35 + 448); + xnn_prefetch_to_l1((const int8_t*) w36 + 448); + xnn_prefetch_to_l1((const int8_t*) w37 + 448); + xnn_prefetch_to_l1((const int8_t*) w38 + 448); + xnn_prefetch_to_l1((const int8_t*) w39 + 448); + xnn_prefetch_to_l1((const int8_t*) w40 + 448); + xnn_prefetch_to_l1((const int8_t*) w41 + 448); + xnn_prefetch_to_l1((const int8_t*) w42 + 448); + xnn_prefetch_to_l1((const int8_t*) w43 + 448); + xnn_prefetch_to_l1((const int8_t*) w44 + 448); + xnn_prefetch_to_l1((const int8_t*) w45 + 448); + xnn_prefetch_to_l1((const int8_t*) w46 + 448); + xnn_prefetch_to_l1((const int8_t*) w47 + 448); + xnn_prefetch_to_l1((const int8_t*) w48 + 448); + xnn_prefetch_to_l1((const int8_t*) w49 + 448); + xnn_prefetch_to_l1((const int8_t*) w50 + 448); + xnn_prefetch_to_l1((const int8_t*) w51 + 448); + xnn_prefetch_to_l1((const int8_t*) w52 + 448); + xnn_prefetch_to_l1((const int8_t*) w53 + 448); + xnn_prefetch_to_l1((const int8_t*) w54 + 448); + xnn_prefetch_to_l1((const int8_t*) w55 + 448); + xnn_prefetch_to_l1((const int8_t*) w56 + 448); + xnn_prefetch_to_l1((const int8_t*) w57 + 448); + xnn_prefetch_to_l1((const int8_t*) w58 + 448); + xnn_prefetch_to_l1((const int8_t*) w59 + 448); + xnn_prefetch_to_l1((const int8_t*) w60 + 448); + xnn_prefetch_to_l1((const int8_t*) w61 + 448); + xnn_prefetch_to_l1((const int8_t*) w62 + 448); + xnn_prefetch_to_l1((const int8_t*) w63 + 448); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + 
_mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + w16 += 4; + w17 += 4; + w18 += 4; + w19 += 4; + w20 += 4; + w21 += 4; + w22 += 4; + w23 += 4; + w24 += 4; + w25 += 4; + w26 += 4; + w27 += 4; + w28 += 4; + w29 += 4; + w30 += 4; + w31 += 4; + w32 += 4; + w33 += 4; + w34 += 4; + w35 += 4; + w36 += 4; + w37 += 4; + w38 += 4; + w39 += 4; + w40 += 4; + w41 += 4; + w42 += 4; + w43 += 4; + w44 += 4; + w45 += 4; + w46 += 4; + w47 += 4; + w48 += 4; + w49 += 4; + w50 += 4; + w51 += 4; + w52 += 4; + w53 += 4; + w54 += 4; + w55 += 4; + w56 += 4; + w57 += 4; + w58 += 4; + w59 += 4; + w60 += 4; + w61 += 4; + w62 += 4; + w63 += 4; + out += 256; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) safe_load_u32(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w9, k)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w10, k)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w11, k)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w12, k)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w13, k)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w14, k)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w15, k)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) safe_load_u32(w16, k)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w17, k)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w18, k)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w19, k)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w20, k)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w21, k)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w22, k)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w23, k)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) safe_load_u32(w24, k)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w25, k)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w26, k)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w27, k)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w28, k)), 
0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w29, k)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w30, k)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w31, k)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) safe_load_u32(w32, k)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w33, k)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w34, k)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w35, k)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w36, k)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w37, k)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w38, k)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w39, k)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) safe_load_u32(w40, k)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w41, k)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w42, k)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w43, k)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w44, k)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w45, k)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w46, k)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w47, k)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) safe_load_u32(w48, k)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w49, k)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w50, k)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w51, k)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w52, k)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w53, k)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w54, k)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w55, k)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) safe_load_u32(w56, k)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w57, k)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w58, k)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w59, k)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w60, k)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w61, k)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w62, k)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w63, k)), 0x80); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + w16 += k; + w17 += k; + w18 += k; + w19 += k; + w20 += k; + w21 += k; + w22 += k; + w23 += k; + w24 += k; + w25 += k; + w26 += k; + w27 += k; + w28 += k; + w29 += k; + w30 += k; + w31 += k; + w32 += k; + w33 += k; + w34 += k; + w35 += k; + w36 += k; + w37 += k; + w38 
+= k; + w39 += k; + w40 += k; + w41 += k; + w42 += k; + w43 += k; + w44 += k; + w45 += k; + w46 += k; + w47 += k; + w48 += k; + w49 += k; + w50 += k; + w51 += k; + w52 += k; + w53 += k; + w54 += k; + w55 += k; + w56 += k; + w57 += k; + w58 += k; + w59 += k; + w60 += k; + w61 += k; + w62 += k; + w63 += k; + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + xnn_prefetch_to_l1((const int8_t*) w16 + 448); + xnn_prefetch_to_l1((const int8_t*) w17 + 448); + xnn_prefetch_to_l1((const int8_t*) w18 + 448); + xnn_prefetch_to_l1((const int8_t*) w19 + 448); + xnn_prefetch_to_l1((const int8_t*) w20 + 448); + xnn_prefetch_to_l1((const int8_t*) w21 + 448); + xnn_prefetch_to_l1((const int8_t*) w22 + 448); + xnn_prefetch_to_l1((const int8_t*) w23 + 448); + xnn_prefetch_to_l1((const int8_t*) w24 + 448); + xnn_prefetch_to_l1((const int8_t*) w25 + 448); + xnn_prefetch_to_l1((const int8_t*) w26 + 448); + xnn_prefetch_to_l1((const int8_t*) w27 + 448); + xnn_prefetch_to_l1((const int8_t*) w28 + 448); + xnn_prefetch_to_l1((const int8_t*) w29 + 448); + xnn_prefetch_to_l1((const int8_t*) w30 + 448); + xnn_prefetch_to_l1((const int8_t*) w31 + 448); + xnn_prefetch_to_l1((const int8_t*) w32 + 448); + xnn_prefetch_to_l1((const int8_t*) w33 + 448); + xnn_prefetch_to_l1((const int8_t*) w34 + 448); + xnn_prefetch_to_l1((const int8_t*) w35 + 448); + xnn_prefetch_to_l1((const int8_t*) w36 + 448); + xnn_prefetch_to_l1((const int8_t*) w37 + 448); + xnn_prefetch_to_l1((const int8_t*) w38 + 448); + xnn_prefetch_to_l1((const int8_t*) w39 + 448); + xnn_prefetch_to_l1((const int8_t*) w40 + 448); + xnn_prefetch_to_l1((const int8_t*) w41 + 448); + xnn_prefetch_to_l1((const int8_t*) w42 + 448); + xnn_prefetch_to_l1((const int8_t*) w43 + 448); + xnn_prefetch_to_l1((const int8_t*) w44 + 448); + xnn_prefetch_to_l1((const int8_t*) w45 + 448); + xnn_prefetch_to_l1((const int8_t*) w46 + 448); + xnn_prefetch_to_l1((const int8_t*) w47 + 448); + xnn_prefetch_to_l1((const int8_t*) w48 + 448); + xnn_prefetch_to_l1((const int8_t*) w49 + 448); + xnn_prefetch_to_l1((const int8_t*) w50 + 448); + xnn_prefetch_to_l1((const int8_t*) w51 + 448); + xnn_prefetch_to_l1((const int8_t*) w52 + 448); + xnn_prefetch_to_l1((const int8_t*) w53 + 448); + xnn_prefetch_to_l1((const int8_t*) w54 + 448); + xnn_prefetch_to_l1((const int8_t*) w55 + 448); + xnn_prefetch_to_l1((const int8_t*) w56 + 448); + xnn_prefetch_to_l1((const int8_t*) w57 + 448); + xnn_prefetch_to_l1((const int8_t*) w58 + 448); + xnn_prefetch_to_l1((const int8_t*) w59 + 448); + xnn_prefetch_to_l1((const int8_t*) w60 + 448); + xnn_prefetch_to_l1((const int8_t*) w61 + 448); + xnn_prefetch_to_l1((const int8_t*) w62 + 448); + xnn_prefetch_to_l1((const int8_t*) w63 + 448); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = 
_mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + out += 256; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vksum8 = _mm256_mullo_epi32(vacc8, vzeropoint); + __m256i vksum16 = _mm256_mullo_epi32(vacc16, vzeropoint); + __m256i vksum24 = _mm256_mullo_epi32(vacc24, vzeropoint); + __m256i vksum32 = _mm256_mullo_epi32(vacc32, vzeropoint); + __m256i vksum40 = _mm256_mullo_epi32(vacc40, vzeropoint); + __m256i vksum48 = _mm256_mullo_epi32(vacc48, vzeropoint); + __m256i vksum56 = _mm256_mullo_epi32(vacc56, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + __m256i vpack16 = _mm256_loadu_si256((const __m256i*) (packed_b + 16)); + __m256i vpack24 = _mm256_loadu_si256((const __m256i*) (packed_b + 24)); + __m256i vpack32 = _mm256_loadu_si256((const __m256i*) (packed_b + 32)); + __m256i vpack40 = _mm256_loadu_si256((const __m256i*) (packed_b + 40)); + __m256i vpack48 = _mm256_loadu_si256((const __m256i*) (packed_b + 48)); + __m256i vpack56 = _mm256_loadu_si256((const __m256i*) (packed_b + 56)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + vpack16 = _mm256_sub_epi32(vpack16, vksum16); + vpack24 = _mm256_sub_epi32(vpack24, vksum24); + vpack32 = _mm256_sub_epi32(vpack32, vksum32); + vpack40 = _mm256_sub_epi32(vpack40, vksum40); + vpack48 = _mm256_sub_epi32(vpack48, vksum48); + vpack56 = _mm256_sub_epi32(vpack56, vksum56); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + _mm256_storeu_si256((__m256i *) (packed_b + 16), vpack16); + _mm256_storeu_si256((__m256i *) (packed_b + 24), vpack24); + _mm256_storeu_si256((__m256i *) (packed_b + 32), vpack32); + _mm256_storeu_si256((__m256i *) (packed_b + 40), vpack40); + _mm256_storeu_si256((__m256i *) (packed_b + 48), vpack48); + _mm256_storeu_si256((__m256i *) (packed_b + 56), vpack56); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w63; + } + + // NC remainder (1..63) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 63); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 64), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 96), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 128), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 160), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 192), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 224), _mm256_setzero_si256()); + } + out += 64 * 
sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + const int8_t* w16 = w15 + kc; + if XNN_UNPREDICTABLE(n <= 16) { + w16 = w15; + } + const int8_t* w17 = w16 + kc; + if XNN_UNPREDICTABLE(n < 18) { + w17 = w16; + } + const int8_t* w18 = w17 + kc; + if XNN_UNPREDICTABLE(n <= 18) { + w18 = w17; + } + const int8_t* w19 = w18 + kc; + if XNN_UNPREDICTABLE(n < 20) { + w19 = w18; + } + const int8_t* w20 = w19 + kc; + if XNN_UNPREDICTABLE(n <= 20) { + w20 = w19; + } + const int8_t* w21 = w20 + kc; + if XNN_UNPREDICTABLE(n < 22) { + w21 = w20; + } + const int8_t* w22 = w21 + kc; + if XNN_UNPREDICTABLE(n <= 22) { + w22 = w21; + } + const int8_t* w23 = w22 + kc; + if XNN_UNPREDICTABLE(n < 24) { + w23 = w22; + } + const int8_t* w24 = w23 + kc; + if XNN_UNPREDICTABLE(n <= 24) { + w24 = w23; + } + const int8_t* w25 = w24 + kc; + if XNN_UNPREDICTABLE(n < 26) { + w25 = w24; + } + const int8_t* w26 = w25 + kc; + if XNN_UNPREDICTABLE(n <= 26) { + w26 = w25; + } + const int8_t* w27 = w26 + kc; + if XNN_UNPREDICTABLE(n < 28) { + w27 = w26; + } + const int8_t* w28 = w27 + kc; + if XNN_UNPREDICTABLE(n <= 28) { + w28 = w27; + } + const int8_t* w29 = w28 + kc; + if XNN_UNPREDICTABLE(n < 30) { + w29 = w28; + } + const int8_t* w30 = w29 + kc; + if XNN_UNPREDICTABLE(n <= 30) { + w30 = w29; + } + const int8_t* w31 = w30 + kc; + if XNN_UNPREDICTABLE(n < 32) { + w31 = w30; + } + const int8_t* w32 = w31 + kc; + if XNN_UNPREDICTABLE(n <= 32) { + w32 = w31; + } + const int8_t* w33 = w32 + kc; + if XNN_UNPREDICTABLE(n < 34) { + w33 = w32; + } + const int8_t* w34 = w33 + kc; + if XNN_UNPREDICTABLE(n <= 34) { + w34 = w33; + } + const int8_t* w35 = w34 + kc; + if XNN_UNPREDICTABLE(n < 36) { + w35 = w34; + } + const int8_t* w36 = w35 + kc; + if XNN_UNPREDICTABLE(n <= 36) { + w36 = w35; + } + const int8_t* w37 = w36 + kc; + if XNN_UNPREDICTABLE(n < 38) { + w37 = w36; + } + const int8_t* w38 = w37 + kc; + if XNN_UNPREDICTABLE(n <= 38) { + w38 = w37; + } + const int8_t* w39 = w38 + kc; + if XNN_UNPREDICTABLE(n < 40) { + w39 = w38; + } + const int8_t* w40 = w39 + kc; + if XNN_UNPREDICTABLE(n <= 40) { + w40 = w39; + } + const int8_t* w41 = w40 + kc; + if XNN_UNPREDICTABLE(n < 42) { + w41 = w40; + } + const int8_t* w42 = w41 + kc; + if XNN_UNPREDICTABLE(n <= 42) { + w42 = w41; + } + const int8_t* w43 = w42 + kc; + if XNN_UNPREDICTABLE(n < 44) { + w43 = w42; + } + const int8_t* w44 = w43 + kc; + 
if XNN_UNPREDICTABLE(n <= 44) { + w44 = w43; + } + const int8_t* w45 = w44 + kc; + if XNN_UNPREDICTABLE(n < 46) { + w45 = w44; + } + const int8_t* w46 = w45 + kc; + if XNN_UNPREDICTABLE(n <= 46) { + w46 = w45; + } + const int8_t* w47 = w46 + kc; + if XNN_UNPREDICTABLE(n < 48) { + w47 = w46; + } + const int8_t* w48 = w47 + kc; + if XNN_UNPREDICTABLE(n <= 48) { + w48 = w47; + } + const int8_t* w49 = w48 + kc; + if XNN_UNPREDICTABLE(n < 50) { + w49 = w48; + } + const int8_t* w50 = w49 + kc; + if XNN_UNPREDICTABLE(n <= 50) { + w50 = w49; + } + const int8_t* w51 = w50 + kc; + if XNN_UNPREDICTABLE(n < 52) { + w51 = w50; + } + const int8_t* w52 = w51 + kc; + if XNN_UNPREDICTABLE(n <= 52) { + w52 = w51; + } + const int8_t* w53 = w52 + kc; + if XNN_UNPREDICTABLE(n < 54) { + w53 = w52; + } + const int8_t* w54 = w53 + kc; + if XNN_UNPREDICTABLE(n <= 54) { + w54 = w53; + } + const int8_t* w55 = w54 + kc; + if XNN_UNPREDICTABLE(n < 56) { + w55 = w54; + } + const int8_t* w56 = w55 + kc; + if XNN_UNPREDICTABLE(n <= 56) { + w56 = w55; + } + const int8_t* w57 = w56 + kc; + if XNN_UNPREDICTABLE(n < 58) { + w57 = w56; + } + const int8_t* w58 = w57 + kc; + if XNN_UNPREDICTABLE(n <= 58) { + w58 = w57; + } + const int8_t* w59 = w58 + kc; + if XNN_UNPREDICTABLE(n < 60) { + w59 = w58; + } + const int8_t* w60 = w59 + kc; + if XNN_UNPREDICTABLE(n <= 60) { + w60 = w59; + } + const int8_t* w61 = w60 + kc; + if XNN_UNPREDICTABLE(n < 62) { + w61 = w60; + } + const int8_t* w62 = w61 + kc; + if XNN_UNPREDICTABLE(n <= 62) { + w62 = w61; + } + const int8_t* w63 = w62 + kc; + if XNN_UNPREDICTABLE(n < 64) { + w63 = w62; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc16 = _mm256_setzero_si256(); + __m256i vacc24 = _mm256_setzero_si256(); + __m256i vacc32 = _mm256_setzero_si256(); + __m256i vacc40 = _mm256_setzero_si256(); + __m256i vacc48 = _mm256_setzero_si256(); + __m256i vacc56 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 64x4 + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w9)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w10)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w11)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w12)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w13)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w14)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w15)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w16)); + v16 = _mm256_blend_epi32(v16, 
_mm256_set1_epi32((int32_t) unaligned_load_u32(w17)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w18)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w19)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w20)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w21)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w22)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w23)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w24)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w25)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w26)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w27)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w28)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w29)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w30)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w31)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w32)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w33)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w34)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w35)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w36)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w37)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w38)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w39)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w40)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w41)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w42)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w43)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w44)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w45)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w46)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w47)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w48)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w49)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w50)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w51)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w52)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w53)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w54)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w55)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) 
unaligned_load_u32(w56)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w57)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w58)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w59)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w60)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w61)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w62)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w63)), 0x80); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + xnn_prefetch_to_l1((const int8_t*) w16 + 448); + xnn_prefetch_to_l1((const int8_t*) w17 + 448); + xnn_prefetch_to_l1((const int8_t*) w18 + 448); + xnn_prefetch_to_l1((const int8_t*) w19 + 448); + xnn_prefetch_to_l1((const int8_t*) w20 + 448); + xnn_prefetch_to_l1((const int8_t*) w21 + 448); + xnn_prefetch_to_l1((const int8_t*) w22 + 448); + xnn_prefetch_to_l1((const int8_t*) w23 + 448); + xnn_prefetch_to_l1((const int8_t*) w24 + 448); + xnn_prefetch_to_l1((const int8_t*) w25 + 448); + xnn_prefetch_to_l1((const int8_t*) w26 + 448); + xnn_prefetch_to_l1((const int8_t*) w27 + 448); + xnn_prefetch_to_l1((const int8_t*) w28 + 448); + xnn_prefetch_to_l1((const int8_t*) w29 + 448); + xnn_prefetch_to_l1((const int8_t*) w30 + 448); + xnn_prefetch_to_l1((const int8_t*) w31 + 448); + xnn_prefetch_to_l1((const int8_t*) w32 + 448); + xnn_prefetch_to_l1((const int8_t*) w33 + 448); + xnn_prefetch_to_l1((const int8_t*) w34 + 448); + xnn_prefetch_to_l1((const int8_t*) w35 + 448); + xnn_prefetch_to_l1((const int8_t*) w36 + 448); + xnn_prefetch_to_l1((const int8_t*) w37 + 448); + xnn_prefetch_to_l1((const int8_t*) w38 + 448); + xnn_prefetch_to_l1((const int8_t*) w39 + 448); + xnn_prefetch_to_l1((const int8_t*) w40 + 448); + xnn_prefetch_to_l1((const int8_t*) w41 + 448); + xnn_prefetch_to_l1((const int8_t*) w42 + 448); + xnn_prefetch_to_l1((const int8_t*) w43 + 448); + xnn_prefetch_to_l1((const int8_t*) w44 + 448); + xnn_prefetch_to_l1((const int8_t*) w45 + 448); + xnn_prefetch_to_l1((const int8_t*) w46 + 448); + xnn_prefetch_to_l1((const int8_t*) w47 + 448); + xnn_prefetch_to_l1((const int8_t*) w48 + 448); + xnn_prefetch_to_l1((const int8_t*) w49 + 448); + xnn_prefetch_to_l1((const int8_t*) w50 + 448); + xnn_prefetch_to_l1((const int8_t*) w51 + 448); + xnn_prefetch_to_l1((const int8_t*) w52 + 448); + xnn_prefetch_to_l1((const int8_t*) w53 + 448); + xnn_prefetch_to_l1((const int8_t*) w54 + 448); + xnn_prefetch_to_l1((const int8_t*) w55 + 448); + xnn_prefetch_to_l1((const int8_t*) w56 + 448); + xnn_prefetch_to_l1((const int8_t*) w57 + 448); + xnn_prefetch_to_l1((const int8_t*) w58 + 
448); + xnn_prefetch_to_l1((const int8_t*) w59 + 448); + xnn_prefetch_to_l1((const int8_t*) w60 + 448); + xnn_prefetch_to_l1((const int8_t*) w61 + 448); + xnn_prefetch_to_l1((const int8_t*) w62 + 448); + xnn_prefetch_to_l1((const int8_t*) w63 + 448); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + w16 += 4; + w17 += 4; + w18 += 4; + w19 += 4; + w20 += 4; + w21 += 4; + w22 += 4; + w23 += 4; + w24 += 4; + w25 += 4; + w26 += 4; + w27 += 4; + w28 += 4; + w29 += 4; + w30 += 4; + w31 += 4; + w32 += 4; + w33 += 4; + w34 += 4; + w35 += 4; + w36 += 4; + w37 += 4; + w38 += 4; + w39 += 4; + w40 += 4; + w41 += 4; + w42 += 4; + w43 += 4; + w44 += 4; + w45 += 4; + w46 += 4; + w47 += 4; + w48 += 4; + w49 += 4; + w50 += 4; + w51 += 4; + w52 += 4; + w53 += 4; + w54 += 4; + w55 += 4; + w56 += 4; + w57 += 4; + w58 += 4; + w59 += 4; + w60 += 4; + w61 += 4; + w62 += 4; + w63 += 4; + out += 256; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) safe_load_u32(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w9, k)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w10, k)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w11, k)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w12, k)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w13, k)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w14, k)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w15, k)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) safe_load_u32(w16, k)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w17, k)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w18, k)), 0x04); + v16 = _mm256_blend_epi32(v16, 
_mm256_set1_epi32((int32_t) safe_load_u32(w19, k)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w20, k)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w21, k)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w22, k)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w23, k)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) safe_load_u32(w24, k)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w25, k)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w26, k)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w27, k)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w28, k)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w29, k)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w30, k)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w31, k)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) safe_load_u32(w32, k)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w33, k)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w34, k)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w35, k)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w36, k)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w37, k)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w38, k)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w39, k)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) safe_load_u32(w40, k)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w41, k)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w42, k)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w43, k)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w44, k)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w45, k)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w46, k)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w47, k)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) safe_load_u32(w48, k)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w49, k)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w50, k)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w51, k)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w52, k)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w53, k)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w54, k)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w55, k)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) safe_load_u32(w56, k)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w57, k)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w58, k)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) 
safe_load_u32(w59, k)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w60, k)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w61, k)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w62, k)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w63, k)), 0x80); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + w16 += k; + w17 += k; + w18 += k; + w19 += k; + w20 += k; + w21 += k; + w22 += k; + w23 += k; + w24 += k; + w25 += k; + w26 += k; + w27 += k; + w28 += k; + w29 += k; + w30 += k; + w31 += k; + w32 += k; + w33 += k; + w34 += k; + w35 += k; + w36 += k; + w37 += k; + w38 += k; + w39 += k; + w40 += k; + w41 += k; + w42 += k; + w43 += k; + w44 += k; + w45 += k; + w46 += k; + w47 += k; + w48 += k; + w49 += k; + w50 += k; + w51 += k; + w52 += k; + w53 += k; + w54 += k; + w55 += k; + w56 += k; + w57 += k; + w58 += k; + w59 += k; + w60 += k; + w61 += k; + w62 += k; + w63 += k; + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + xnn_prefetch_to_l1((const int8_t*) w16 + 448); + xnn_prefetch_to_l1((const int8_t*) w17 + 448); + xnn_prefetch_to_l1((const int8_t*) w18 + 448); + xnn_prefetch_to_l1((const int8_t*) w19 + 448); + xnn_prefetch_to_l1((const int8_t*) w20 + 448); + xnn_prefetch_to_l1((const int8_t*) w21 + 448); + xnn_prefetch_to_l1((const int8_t*) w22 + 448); + xnn_prefetch_to_l1((const int8_t*) w23 + 448); + xnn_prefetch_to_l1((const int8_t*) w24 + 448); + xnn_prefetch_to_l1((const int8_t*) w25 + 448); + xnn_prefetch_to_l1((const int8_t*) w26 + 448); + xnn_prefetch_to_l1((const int8_t*) w27 + 448); + xnn_prefetch_to_l1((const int8_t*) w28 + 448); + xnn_prefetch_to_l1((const int8_t*) w29 + 448); + xnn_prefetch_to_l1((const int8_t*) w30 + 448); + xnn_prefetch_to_l1((const int8_t*) w31 + 448); + xnn_prefetch_to_l1((const int8_t*) w32 + 448); + xnn_prefetch_to_l1((const int8_t*) w33 + 448); + xnn_prefetch_to_l1((const int8_t*) w34 + 448); + xnn_prefetch_to_l1((const int8_t*) w35 + 448); + xnn_prefetch_to_l1((const int8_t*) w36 + 448); + xnn_prefetch_to_l1((const int8_t*) w37 + 448); + xnn_prefetch_to_l1((const int8_t*) w38 + 448); + xnn_prefetch_to_l1((const int8_t*) w39 + 448); + xnn_prefetch_to_l1((const int8_t*) w40 + 448); + xnn_prefetch_to_l1((const int8_t*) w41 + 448); + xnn_prefetch_to_l1((const int8_t*) w42 + 448); + xnn_prefetch_to_l1((const int8_t*) w43 + 448); + xnn_prefetch_to_l1((const int8_t*) w44 + 448); + xnn_prefetch_to_l1((const int8_t*) w45 + 448); + xnn_prefetch_to_l1((const int8_t*) w46 + 448); + xnn_prefetch_to_l1((const int8_t*) w47 + 448); + xnn_prefetch_to_l1((const int8_t*) w48 + 448); + 
xnn_prefetch_to_l1((const int8_t*) w49 + 448); + xnn_prefetch_to_l1((const int8_t*) w50 + 448); + xnn_prefetch_to_l1((const int8_t*) w51 + 448); + xnn_prefetch_to_l1((const int8_t*) w52 + 448); + xnn_prefetch_to_l1((const int8_t*) w53 + 448); + xnn_prefetch_to_l1((const int8_t*) w54 + 448); + xnn_prefetch_to_l1((const int8_t*) w55 + 448); + xnn_prefetch_to_l1((const int8_t*) w56 + 448); + xnn_prefetch_to_l1((const int8_t*) w57 + 448); + xnn_prefetch_to_l1((const int8_t*) w58 + 448); + xnn_prefetch_to_l1((const int8_t*) w59 + 448); + xnn_prefetch_to_l1((const int8_t*) w60 + 448); + xnn_prefetch_to_l1((const int8_t*) w61 + 448); + xnn_prefetch_to_l1((const int8_t*) w62 + 448); + xnn_prefetch_to_l1((const int8_t*) w63 + 448); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + out += 256; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vksum8 = _mm256_mullo_epi32(vacc8, vzeropoint); + __m256i vksum16 = _mm256_mullo_epi32(vacc16, vzeropoint); + __m256i vksum24 = _mm256_mullo_epi32(vacc24, vzeropoint); + __m256i vksum32 = _mm256_mullo_epi32(vacc32, vzeropoint); + __m256i vksum40 = _mm256_mullo_epi32(vacc40, vzeropoint); + __m256i vksum48 = _mm256_mullo_epi32(vacc48, vzeropoint); + __m256i vksum56 = _mm256_mullo_epi32(vacc56, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + __m256i vpack16 = _mm256_loadu_si256((const __m256i*) (packed_b + 16)); + __m256i vpack24 = _mm256_loadu_si256((const __m256i*) (packed_b + 24)); + __m256i vpack32 = _mm256_loadu_si256((const __m256i*) (packed_b + 32)); + __m256i vpack40 = _mm256_loadu_si256((const __m256i*) (packed_b + 40)); + __m256i vpack48 = _mm256_loadu_si256((const __m256i*) (packed_b + 48)); + __m256i vpack56 = _mm256_loadu_si256((const __m256i*) (packed_b + 56)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + vpack16 = _mm256_sub_epi32(vpack16, vksum16); + vpack24 = _mm256_sub_epi32(vpack24, vksum24); + vpack32 = _mm256_sub_epi32(vpack32, vksum32); + vpack40 = _mm256_sub_epi32(vpack40, vksum40); + vpack48 = _mm256_sub_epi32(vpack48, vksum48); + vpack56 = _mm256_sub_epi32(vpack56, vksum56); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + _mm256_storeu_si256((__m256i *) (packed_b + 16), vpack16); + _mm256_storeu_si256((__m256i *) (packed_b + 24), vpack24); + _mm256_storeu_si256((__m256i *) (packed_b + 32), vpack32); + _mm256_storeu_si256((__m256i *) (packed_b + 40), vpack40); + _mm256_storeu_si256((__m256i *) (packed_b + 48), vpack48); + _mm256_storeu_si256((__m256i *) (packed_b + 56), vpack56); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * 
kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni.c b/src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni.c new file mode 100644 index 000000000000..a404d2aa5896 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-avx256vnni.c @@ -0,0 +1,1147 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/c4-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + +XNN_INLINE static uint32_t safe_load_u32(const void* src, size_t k) { + uint32_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < k; ++i) { + value |= (uint32_t) bytes[i] << (i * 8); + } + return value; +} + + +void xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 64); + assert(kr == 4); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); + + do { + // NC main loop multiple of 64 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 64; n -= 64) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + const int8_t* w16 = w15 + kc; + const int8_t* w17 = w16 + kc; + const int8_t* w18 = w17 + kc; + const int8_t* w19 = w18 + kc; + const int8_t* w20 = w19 + kc; + const int8_t* w21 = w20 + kc; + const int8_t* w22 = w21 + kc; + const int8_t* w23 = w22 + kc; + const int8_t* w24 = w23 + kc; + const int8_t* w25 = w24 + kc; + const int8_t* w26 = w25 + kc; + const int8_t* w27 = w26 + kc; + const int8_t* w28 = w27 + kc; + const int8_t* w29 = w28 + kc; + const int8_t* w30 = w29 + kc; + const int8_t* w31 = w30 + kc; + const int8_t* w32 = w31 + kc; + const int8_t* w33 = w32 + kc; + const int8_t* w34 = w33 + kc; + const int8_t* w35 = w34 + kc; + const int8_t* w36 = w35 + kc; + const int8_t* w37 = w36 + kc; + const int8_t* w38 = w37 + kc; + const int8_t* w39 = w38 + kc; + const int8_t* w40 = w39 + kc; + const int8_t* w41 = w40 + kc; + const int8_t* w42 = w41 + kc; + const int8_t* w43 = w42 + kc; + const int8_t* w44 = w43 + kc; + const int8_t* w45 = w44 + kc; + const int8_t* w46 = w45 + kc; + const int8_t* w47 = w46 + kc; + const int8_t* w48 = w47 + kc; + const int8_t* w49 = w48 + kc; + const int8_t* w50 = w49 + kc; + const int8_t* w51 = w50 + kc; + const int8_t* w52 = w51 + kc; + const int8_t* w53 = w52 + kc; + const int8_t* w54 = w53 + kc; + const int8_t* w55
= w54 + kc; + const int8_t* w56 = w55 + kc; + const int8_t* w57 = w56 + kc; + const int8_t* w58 = w57 + kc; + const int8_t* w59 = w58 + kc; + const int8_t* w60 = w59 + kc; + const int8_t* w61 = w60 + kc; + const int8_t* w62 = w61 + kc; + const int8_t* w63 = w62 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + const __m256i vb16 = _mm256_loadu_si256((const __m256i*) (b + 16)); + const __m256i vb24 = _mm256_loadu_si256((const __m256i*) (b + 24)); + const __m256i vb32 = _mm256_loadu_si256((const __m256i*) (b + 32)); + const __m256i vb40 = _mm256_loadu_si256((const __m256i*) (b + 40)); + const __m256i vb48 = _mm256_loadu_si256((const __m256i*) (b + 48)); + const __m256i vb56 = _mm256_loadu_si256((const __m256i*) (b + 56)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + _mm256_storeu_si256((__m256i*) (out + 64), vb16); + _mm256_storeu_si256((__m256i*) (out + 96), vb24); + _mm256_storeu_si256((__m256i*) (out + 128), vb32); + _mm256_storeu_si256((__m256i*) (out + 160), vb40); + _mm256_storeu_si256((__m256i*) (out + 192), vb48); + _mm256_storeu_si256((__m256i*) (out + 224), vb56); + b += 64; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 64), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 96), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 128), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 160), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 192), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 224), _mm256_setzero_si256()); + } + out += 64 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc16 = _mm256_setzero_si256(); + __m256i vacc24 = _mm256_setzero_si256(); + __m256i vacc32 = _mm256_setzero_si256(); + __m256i vacc40 = _mm256_setzero_si256(); + __m256i vacc48 = _mm256_setzero_si256(); + __m256i vacc56 = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of 64x4 + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w9)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w10)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w11)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w12)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) 
unaligned_load_u32(w13)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w14)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w15)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w16)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w17)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w18)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w19)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w20)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w21)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w22)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w23)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w24)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w25)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w26)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w27)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w28)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w29)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w30)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w31)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w32)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w33)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w34)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w35)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w36)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w37)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w38)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w39)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w40)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w41)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w42)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w43)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w44)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w45)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w46)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w47)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w48)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w49)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w50)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w51)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w52)), 0x10); + v48 = 
_mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w53)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w54)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w55)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w56)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w57)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w58)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w59)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w60)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w61)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w62)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w63)), 0x80); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + w16 += 4; + w17 += 4; + w18 += 4; + w19 += 4; + w20 += 4; + w21 += 4; + w22 += 4; + w23 += 4; + w24 += 4; + w25 += 4; + w26 += 4; + w27 += 4; + w28 += 4; + w29 += 4; + w30 += 4; + w31 += 4; + w32 += 4; + w33 += 4; + w34 += 4; + w35 += 4; + w36 += 4; + w37 += 4; + w38 += 4; + w39 += 4; + w40 += 4; + w41 += 4; + w42 += 4; + w43 += 4; + w44 += 4; + w45 += 4; + w46 += 4; + w47 += 4; + w48 += 4; + w49 += 4; + w50 += 4; + w51 += 4; + w52 += 4; + w53 += 4; + w54 += 4; + w55 += 4; + w56 += 4; + w57 += 4; + w58 += 4; + w59 += 4; + w60 += 4; + w61 += 4; + w62 += 4; + w63 += 4; + out += 256; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) safe_load_u32(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w9, k)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) 
safe_load_u32(w10, k)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w11, k)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w12, k)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w13, k)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w14, k)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w15, k)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) safe_load_u32(w16, k)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w17, k)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w18, k)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w19, k)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w20, k)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w21, k)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w22, k)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w23, k)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) safe_load_u32(w24, k)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w25, k)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w26, k)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w27, k)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w28, k)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w29, k)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w30, k)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w31, k)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) safe_load_u32(w32, k)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w33, k)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w34, k)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w35, k)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w36, k)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w37, k)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w38, k)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w39, k)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) safe_load_u32(w40, k)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w41, k)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w42, k)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w43, k)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w44, k)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w45, k)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w46, k)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w47, k)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) safe_load_u32(w48, k)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w49, k)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w50, k)), 0x04); + v48 = 
_mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w51, k)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w52, k)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w53, k)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w54, k)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w55, k)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) safe_load_u32(w56, k)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w57, k)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w58, k)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w59, k)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w60, k)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w61, k)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w62, k)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w63, k)), 0x80); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + w16 += k; + w17 += k; + w18 += k; + w19 += k; + w20 += k; + w21 += k; + w22 += k; + w23 += k; + w24 += k; + w25 += k; + w26 += k; + w27 += k; + w28 += k; + w29 += k; + w30 += k; + w31 += k; + w32 += k; + w33 += k; + w34 += k; + w35 += k; + w36 += k; + w37 += k; + w38 += k; + w39 += k; + w40 += k; + w41 += k; + w42 += k; + w43 += k; + w44 += k; + w45 += k; + w46 += k; + w47 += k; + w48 += k; + w49 += k; + w50 += k; + w51 += k; + w52 += k; + w53 += k; + w54 += k; + w55 += k; + w56 += k; + w57 += k; + w58 += k; + w59 += k; + w60 += k; + w61 += k; + w62 += k; + w63 += k; + + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + out += 256; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vksum8 = _mm256_mullo_epi32(vacc8, vzeropoint); + __m256i vksum16 = _mm256_mullo_epi32(vacc16, vzeropoint); + __m256i vksum24 = _mm256_mullo_epi32(vacc24, vzeropoint); + __m256i vksum32 = _mm256_mullo_epi32(vacc32, vzeropoint); + __m256i vksum40 = _mm256_mullo_epi32(vacc40, vzeropoint); + __m256i vksum48 = _mm256_mullo_epi32(vacc48, vzeropoint); + __m256i vksum56 = _mm256_mullo_epi32(vacc56, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + __m256i vpack16 = _mm256_loadu_si256((const __m256i*) (packed_b + 16)); + __m256i vpack24 = _mm256_loadu_si256((const __m256i*) (packed_b + 24)); + __m256i vpack32 = _mm256_loadu_si256((const 
__m256i*) (packed_b + 32)); + __m256i vpack40 = _mm256_loadu_si256((const __m256i*) (packed_b + 40)); + __m256i vpack48 = _mm256_loadu_si256((const __m256i*) (packed_b + 48)); + __m256i vpack56 = _mm256_loadu_si256((const __m256i*) (packed_b + 56)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + vpack16 = _mm256_sub_epi32(vpack16, vksum16); + vpack24 = _mm256_sub_epi32(vpack24, vksum24); + vpack32 = _mm256_sub_epi32(vpack32, vksum32); + vpack40 = _mm256_sub_epi32(vpack40, vksum40); + vpack48 = _mm256_sub_epi32(vpack48, vksum48); + vpack56 = _mm256_sub_epi32(vpack56, vksum56); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + _mm256_storeu_si256((__m256i *) (packed_b + 16), vpack16); + _mm256_storeu_si256((__m256i *) (packed_b + 24), vpack24); + _mm256_storeu_si256((__m256i *) (packed_b + 32), vpack32); + _mm256_storeu_si256((__m256i *) (packed_b + 40), vpack40); + _mm256_storeu_si256((__m256i *) (packed_b + 48), vpack48); + _mm256_storeu_si256((__m256i *) (packed_b + 56), vpack56); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w63; + } + + // NC remainder (1..63) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 63); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 64), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 96), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 128), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 160), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 192), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 224), _mm256_setzero_si256()); + } + out += 64 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + const int8_t* w16 = w15 + kc; + if XNN_UNPREDICTABLE(n <= 16) { + w16 = w15; + } + const int8_t* w17 = w16 + kc; + if XNN_UNPREDICTABLE(n < 18) { + w17 = w16; + } + const int8_t* w18 = w17 + kc; + if XNN_UNPREDICTABLE(n <= 18) { + w18 = w17; + } + const int8_t* w19 = w18 + kc; + if XNN_UNPREDICTABLE(n < 20) { + w19 = w18; 
+ } + const int8_t* w20 = w19 + kc; + if XNN_UNPREDICTABLE(n <= 20) { + w20 = w19; + } + const int8_t* w21 = w20 + kc; + if XNN_UNPREDICTABLE(n < 22) { + w21 = w20; + } + const int8_t* w22 = w21 + kc; + if XNN_UNPREDICTABLE(n <= 22) { + w22 = w21; + } + const int8_t* w23 = w22 + kc; + if XNN_UNPREDICTABLE(n < 24) { + w23 = w22; + } + const int8_t* w24 = w23 + kc; + if XNN_UNPREDICTABLE(n <= 24) { + w24 = w23; + } + const int8_t* w25 = w24 + kc; + if XNN_UNPREDICTABLE(n < 26) { + w25 = w24; + } + const int8_t* w26 = w25 + kc; + if XNN_UNPREDICTABLE(n <= 26) { + w26 = w25; + } + const int8_t* w27 = w26 + kc; + if XNN_UNPREDICTABLE(n < 28) { + w27 = w26; + } + const int8_t* w28 = w27 + kc; + if XNN_UNPREDICTABLE(n <= 28) { + w28 = w27; + } + const int8_t* w29 = w28 + kc; + if XNN_UNPREDICTABLE(n < 30) { + w29 = w28; + } + const int8_t* w30 = w29 + kc; + if XNN_UNPREDICTABLE(n <= 30) { + w30 = w29; + } + const int8_t* w31 = w30 + kc; + if XNN_UNPREDICTABLE(n < 32) { + w31 = w30; + } + const int8_t* w32 = w31 + kc; + if XNN_UNPREDICTABLE(n <= 32) { + w32 = w31; + } + const int8_t* w33 = w32 + kc; + if XNN_UNPREDICTABLE(n < 34) { + w33 = w32; + } + const int8_t* w34 = w33 + kc; + if XNN_UNPREDICTABLE(n <= 34) { + w34 = w33; + } + const int8_t* w35 = w34 + kc; + if XNN_UNPREDICTABLE(n < 36) { + w35 = w34; + } + const int8_t* w36 = w35 + kc; + if XNN_UNPREDICTABLE(n <= 36) { + w36 = w35; + } + const int8_t* w37 = w36 + kc; + if XNN_UNPREDICTABLE(n < 38) { + w37 = w36; + } + const int8_t* w38 = w37 + kc; + if XNN_UNPREDICTABLE(n <= 38) { + w38 = w37; + } + const int8_t* w39 = w38 + kc; + if XNN_UNPREDICTABLE(n < 40) { + w39 = w38; + } + const int8_t* w40 = w39 + kc; + if XNN_UNPREDICTABLE(n <= 40) { + w40 = w39; + } + const int8_t* w41 = w40 + kc; + if XNN_UNPREDICTABLE(n < 42) { + w41 = w40; + } + const int8_t* w42 = w41 + kc; + if XNN_UNPREDICTABLE(n <= 42) { + w42 = w41; + } + const int8_t* w43 = w42 + kc; + if XNN_UNPREDICTABLE(n < 44) { + w43 = w42; + } + const int8_t* w44 = w43 + kc; + if XNN_UNPREDICTABLE(n <= 44) { + w44 = w43; + } + const int8_t* w45 = w44 + kc; + if XNN_UNPREDICTABLE(n < 46) { + w45 = w44; + } + const int8_t* w46 = w45 + kc; + if XNN_UNPREDICTABLE(n <= 46) { + w46 = w45; + } + const int8_t* w47 = w46 + kc; + if XNN_UNPREDICTABLE(n < 48) { + w47 = w46; + } + const int8_t* w48 = w47 + kc; + if XNN_UNPREDICTABLE(n <= 48) { + w48 = w47; + } + const int8_t* w49 = w48 + kc; + if XNN_UNPREDICTABLE(n < 50) { + w49 = w48; + } + const int8_t* w50 = w49 + kc; + if XNN_UNPREDICTABLE(n <= 50) { + w50 = w49; + } + const int8_t* w51 = w50 + kc; + if XNN_UNPREDICTABLE(n < 52) { + w51 = w50; + } + const int8_t* w52 = w51 + kc; + if XNN_UNPREDICTABLE(n <= 52) { + w52 = w51; + } + const int8_t* w53 = w52 + kc; + if XNN_UNPREDICTABLE(n < 54) { + w53 = w52; + } + const int8_t* w54 = w53 + kc; + if XNN_UNPREDICTABLE(n <= 54) { + w54 = w53; + } + const int8_t* w55 = w54 + kc; + if XNN_UNPREDICTABLE(n < 56) { + w55 = w54; + } + const int8_t* w56 = w55 + kc; + if XNN_UNPREDICTABLE(n <= 56) { + w56 = w55; + } + const int8_t* w57 = w56 + kc; + if XNN_UNPREDICTABLE(n < 58) { + w57 = w56; + } + const int8_t* w58 = w57 + kc; + if XNN_UNPREDICTABLE(n <= 58) { + w58 = w57; + } + const int8_t* w59 = w58 + kc; + if XNN_UNPREDICTABLE(n < 60) { + w59 = w58; + } + const int8_t* w60 = w59 + kc; + if XNN_UNPREDICTABLE(n <= 60) { + w60 = w59; + } + const int8_t* w61 = w60 + kc; + if XNN_UNPREDICTABLE(n < 62) { + w61 = w60; + } + const int8_t* w62 = w61 + kc; + if XNN_UNPREDICTABLE(n <= 62) { + w62 = w61; + } + 
const int8_t* w63 = w62 + kc; + if XNN_UNPREDICTABLE(n < 64) { + w63 = w62; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc16 = _mm256_setzero_si256(); + __m256i vacc24 = _mm256_setzero_si256(); + __m256i vacc32 = _mm256_setzero_si256(); + __m256i vacc40 = _mm256_setzero_si256(); + __m256i vacc48 = _mm256_setzero_si256(); + __m256i vacc56 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 64x4 + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w9)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w10)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w11)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w12)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w13)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w14)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) unaligned_load_u32(w15)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w16)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w17)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w18)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w19)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w20)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w21)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w22)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) unaligned_load_u32(w23)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w24)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w25)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w26)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w27)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w28)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w29)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w30)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) unaligned_load_u32(w31)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w32)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w33)), 0x02); + v32 = _mm256_blend_epi32(v32, 
_mm256_set1_epi32((int32_t) unaligned_load_u32(w34)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w35)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w36)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w37)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w38)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) unaligned_load_u32(w39)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w40)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w41)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w42)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w43)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w44)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w45)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w46)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) unaligned_load_u32(w47)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w48)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w49)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w50)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w51)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w52)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w53)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w54)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) unaligned_load_u32(w55)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w56)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w57)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w58)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w59)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w60)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w61)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w62)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) unaligned_load_u32(w63)), 0x80); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 
4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + w16 += 4; + w17 += 4; + w18 += 4; + w19 += 4; + w20 += 4; + w21 += 4; + w22 += 4; + w23 += 4; + w24 += 4; + w25 += 4; + w26 += 4; + w27 += 4; + w28 += 4; + w29 += 4; + w30 += 4; + w31 += 4; + w32 += 4; + w33 += 4; + w34 += 4; + w35 += 4; + w36 += 4; + w37 += 4; + w38 += 4; + w39 += 4; + w40 += 4; + w41 += 4; + w42 += 4; + w43 += 4; + w44 += 4; + w45 += 4; + w46 += 4; + w47 += 4; + w48 += 4; + w49 += 4; + w50 += 4; + w51 += 4; + w52 += 4; + w53 += 4; + w54 += 4; + w55 += 4; + w56 += 4; + w57 += 4; + w58 += 4; + w59 += 4; + w60 += 4; + w61 += 4; + w62 += 4; + w63 += 4; + out += 256; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + __m256i v8 = _mm256_set1_epi32((int32_t) safe_load_u32(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w9, k)), 0x02); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w10, k)), 0x04); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w11, k)), 0x08); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w12, k)), 0x10); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w13, k)), 0x20); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w14, k)), 0x40); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi32((int32_t) safe_load_u32(w15, k)), 0x80); + __m256i v16 = _mm256_set1_epi32((int32_t) safe_load_u32(w16, k)); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w17, k)), 0x02); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w18, k)), 0x04); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w19, k)), 0x08); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w20, k)), 0x10); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w21, k)), 0x20); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w22, k)), 0x40); + v16 = _mm256_blend_epi32(v16, _mm256_set1_epi32((int32_t) safe_load_u32(w23, k)), 0x80); + __m256i v24 = _mm256_set1_epi32((int32_t) safe_load_u32(w24, k)); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w25, k)), 0x02); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w26, k)), 0x04); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w27, k)), 0x08); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w28, k)), 0x10); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w29, k)), 0x20); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w30, k)), 0x40); + v24 = _mm256_blend_epi32(v24, _mm256_set1_epi32((int32_t) safe_load_u32(w31, 
k)), 0x80); + __m256i v32 = _mm256_set1_epi32((int32_t) safe_load_u32(w32, k)); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w33, k)), 0x02); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w34, k)), 0x04); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w35, k)), 0x08); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w36, k)), 0x10); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w37, k)), 0x20); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w38, k)), 0x40); + v32 = _mm256_blend_epi32(v32, _mm256_set1_epi32((int32_t) safe_load_u32(w39, k)), 0x80); + __m256i v40 = _mm256_set1_epi32((int32_t) safe_load_u32(w40, k)); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w41, k)), 0x02); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w42, k)), 0x04); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w43, k)), 0x08); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w44, k)), 0x10); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w45, k)), 0x20); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w46, k)), 0x40); + v40 = _mm256_blend_epi32(v40, _mm256_set1_epi32((int32_t) safe_load_u32(w47, k)), 0x80); + __m256i v48 = _mm256_set1_epi32((int32_t) safe_load_u32(w48, k)); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w49, k)), 0x02); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w50, k)), 0x04); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w51, k)), 0x08); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w52, k)), 0x10); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w53, k)), 0x20); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w54, k)), 0x40); + v48 = _mm256_blend_epi32(v48, _mm256_set1_epi32((int32_t) safe_load_u32(w55, k)), 0x80); + __m256i v56 = _mm256_set1_epi32((int32_t) safe_load_u32(w56, k)); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w57, k)), 0x02); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w58, k)), 0x04); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w59, k)), 0x08); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w60, k)), 0x10); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w61, k)), 0x20); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w62, k)), 0x40); + v56 = _mm256_blend_epi32(v56, _mm256_set1_epi32((int32_t) safe_load_u32(w63, k)), 0x80); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + w16 += k; + w17 += k; + w18 += k; + w19 += k; + w20 += k; + w21 += k; + w22 += k; + w23 += k; + w24 += k; + w25 += k; + w26 += k; + w27 += k; + w28 += k; + w29 += k; + w30 += k; + w31 += k; + w32 += k; + w33 += k; + w34 += k; + w35 += k; + w36 += k; + w37 += k; + w38 += k; + w39 += k; + w40 += k; + w41 += k; + w42 += k; + w43 += k; + w44 += k; + w45 += k; + w46 += k; + w47 += k; + w48 += k; + w49 += k; + w50 += k; + w51 += k; + w52 += k; + w53 += k; + w54 += k; + w55 += k; + w56 += k; + w57 += k; + w58 += k; + w59 += k; + w60 += 
k; + w61 += k; + w62 += k; + w63 += k; + + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc16 = _mm256_dpbusd_epi32(vacc16, vone, v16); + vacc24 = _mm256_dpbusd_epi32(vacc24, vone, v24); + vacc32 = _mm256_dpbusd_epi32(vacc32, vone, v32); + vacc40 = _mm256_dpbusd_epi32(vacc40, vone, v40); + vacc48 = _mm256_dpbusd_epi32(vacc48, vone, v48); + vacc56 = _mm256_dpbusd_epi32(vacc56, vone, v56); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v8); + _mm256_storeu_si256((__m256i *)&out[64], v16); + _mm256_storeu_si256((__m256i *)&out[96], v24); + _mm256_storeu_si256((__m256i *)&out[128], v32); + _mm256_storeu_si256((__m256i *)&out[160], v40); + _mm256_storeu_si256((__m256i *)&out[192], v48); + _mm256_storeu_si256((__m256i *)&out[224], v56); + + out += 256; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vksum8 = _mm256_mullo_epi32(vacc8, vzeropoint); + __m256i vksum16 = _mm256_mullo_epi32(vacc16, vzeropoint); + __m256i vksum24 = _mm256_mullo_epi32(vacc24, vzeropoint); + __m256i vksum32 = _mm256_mullo_epi32(vacc32, vzeropoint); + __m256i vksum40 = _mm256_mullo_epi32(vacc40, vzeropoint); + __m256i vksum48 = _mm256_mullo_epi32(vacc48, vzeropoint); + __m256i vksum56 = _mm256_mullo_epi32(vacc56, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + __m256i vpack16 = _mm256_loadu_si256((const __m256i*) (packed_b + 16)); + __m256i vpack24 = _mm256_loadu_si256((const __m256i*) (packed_b + 24)); + __m256i vpack32 = _mm256_loadu_si256((const __m256i*) (packed_b + 32)); + __m256i vpack40 = _mm256_loadu_si256((const __m256i*) (packed_b + 40)); + __m256i vpack48 = _mm256_loadu_si256((const __m256i*) (packed_b + 48)); + __m256i vpack56 = _mm256_loadu_si256((const __m256i*) (packed_b + 56)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + vpack16 = _mm256_sub_epi32(vpack16, vksum16); + vpack24 = _mm256_sub_epi32(vpack24, vksum24); + vpack32 = _mm256_sub_epi32(vpack32, vksum32); + vpack40 = _mm256_sub_epi32(vpack40, vksum40); + vpack48 = _mm256_sub_epi32(vpack48, vksum48); + vpack56 = _mm256_sub_epi32(vpack56, vksum56); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + _mm256_storeu_si256((__m256i *) (packed_b + 16), vpack16); + _mm256_storeu_si256((__m256i *) (packed_b + 24), vpack24); + _mm256_storeu_si256((__m256i *) (packed_b + 32), vpack32); + _mm256_storeu_si256((__m256i *) (packed_b + 40), vpack40); + _mm256_storeu_si256((__m256i *) (packed_b + 48), vpack48); + _mm256_storeu_si256((__m256i *) (packed_b + 56), vpack56); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni-prfm.c new file mode 100644 index 000000000000..14aafa959eb5 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni-prfm.c @@ -0,0 +1,438 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-gio-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* src, size_t n) { + uint64_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t)bytes[i] << (i * 8); + } + return value; +} + +void xnn_qs8_packw_gemm_gio_ukernel_x8c8__avxvnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t k_stride, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + k_stride; + const int8_t* w2 = w1 + k_stride; + const int8_t* w3 = w2 + k_stride; + const int8_t* w4 = w3 + k_stride; + const int8_t* w5 = w4 + k_stride; + const int8_t* w6 = w5 + k_stride; + const int8_t* w7 = w6 + k_stride; + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const
int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m128i v0x01234567 = _mm_loadu_si64(w0); + __m128i v1x01234567 = _mm_loadu_si64(w1); + __m128i v2x01234567 = _mm_loadu_si64(w2); + __m128i v3x01234567 = _mm_loadu_si64(w3); + __m128i v4x01234567 = _mm_loadu_si64(w4); + __m128i v5x01234567 = _mm_loadu_si64(w5); + __m128i v6x01234567 = _mm_loadu_si64(w6); + __m128i v7x01234567 = _mm_loadu_si64(w7); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8 * k_stride; + w1 += 8 * k_stride; + w2 += 8 * k_stride; + w3 += 8 * k_stride; + w4 += 8 * k_stride; + w5 += 8 * k_stride; + w6 += 8 * k_stride; + w7 += 8 * k_stride; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m128i vzero = _mm_setzero_si128(); + __m128i v0x01234567 = _mm_loadu_si64(w0); + __m128i v1x01234567 = vzero; + if (1 < k) { + v1x01234567 = _mm_loadu_si64(w1); + } 
+ __m128i v2x01234567 = vzero; + if (2 < k) { + v2x01234567 = _mm_loadu_si64(w2); + } + __m128i v3x01234567 = vzero; + if (3 < k) { + v3x01234567 = _mm_loadu_si64(w3); + } + __m128i v4x01234567 = vzero; + if (4 < k) { + v4x01234567 = _mm_loadu_si64(w4); + } + __m128i v5x01234567 = vzero; + if (5 < k) { + v5x01234567 = _mm_loadu_si64(w5); + } + __m128i v6x01234567 = vzero; + if (6 < k) { + v6x01234567 = _mm_loadu_si64(w6); + } + __m128i v7x01234567 = vzero; + if (7 < k) { + v7x01234567 = _mm_loadu_si64(w7); + } + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += k * k_stride; + w1 += k * k_stride; + w2 += k * k_stride; + w3 += k * k_stride; + w4 += k * k_stride; + w5 += k * k_stride; + w6 += k * k_stride; + w7 += k * k_stride; + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w0 - kc * k_stride + 8; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(int32_t); + + const int8_t* w1 = w0 + k_stride; + const int8_t* w2 = w1 + k_stride; + const int8_t* w3 = w2 + k_stride; + const int8_t* w4 = w3 + k_stride; + const int8_t* w5 = w4 + k_stride; + const int8_t* w6 = w5 + k_stride; + const int8_t* w7 = w6 + k_stride; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m128i v0x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w0, n)); + __m128i v1x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w1, n)); + __m128i v2x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w2, n)); + __m128i v3x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w3, n)); + __m128i v4x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w4, n)); + 
__m128i v5x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w5, n)); + __m128i v6x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w6, n)); + __m128i v7x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w7, n)); + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8 * k_stride; + w1 += 8 * k_stride; + w2 += 8 * k_stride; + w3 += 8 * k_stride; + w4 += 8 * k_stride; + w5 += 8 * k_stride; + w6 += 8 * k_stride; + w7 += 8 * k_stride; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m128i vzero = _mm_setzero_si128(); + __m128i v0x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w0, n)); + __m128i v1x01234567 = vzero; + if (1 < k) { + v1x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w1, n)); + } + __m128i v2x01234567 = vzero; + if (2 < k) { + v2x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w2, n)); + } + __m128i v3x01234567 = vzero; + if (3 < k) { + v3x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w3, n)); + } + __m128i v4x01234567 = vzero; + if (4 < k) { + v4x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w4, n)); + } + __m128i v5x01234567 = vzero; + if (5 < k) { + v5x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w5, n)); + } + __m128i v6x01234567 = vzero; + if (6 < k) { + v6x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w6, n)); + } + __m128i v7x01234567 = vzero; + if (7 < k) { + v7x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w7, n)); + } + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = 
_mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += k * k_stride; + w1 += k * k_stride; + w2 += k * k_stride; + w3 += k * k_stride; + w4 += k * k_stride; + w5 += k * k_stride; + w6 += k * k_stride; + w7 += k * k_stride; + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w0 - kc * k_stride + 8; + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni.c new file mode 100644 index 000000000000..5e9b6354a144 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-gio-avxvnni.c @@ -0,0 +1,373 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-gio-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* src, size_t n) { + uint64_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t)bytes[i] << (i * 8); + } + return value; +} + +void xnn_qs8_packw_gemm_gio_ukernel_x8c8__avxvnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t k_stride, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const uint32_t izp = (uint32_t) (params ?
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + k_stride; + const int8_t* w2 = w1 + k_stride; + const int8_t* w3 = w2 + k_stride; + const int8_t* w4 = w3 + k_stride; + const int8_t* w5 = w4 + k_stride; + const int8_t* w6 = w5 + k_stride; + const int8_t* w7 = w6 + k_stride; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m128i v0x01234567 = _mm_loadu_si64(w0); + __m128i v1x01234567 = _mm_loadu_si64(w1); + __m128i v2x01234567 = _mm_loadu_si64(w2); + __m128i v3x01234567 = _mm_loadu_si64(w3); + __m128i v4x01234567 = _mm_loadu_si64(w4); + __m128i v5x01234567 = _mm_loadu_si64(w5); + __m128i v6x01234567 = _mm_loadu_si64(w6); + __m128i v7x01234567 = _mm_loadu_si64(w7); + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8 * k_stride; + w1 += 8 * k_stride; + w2 += 8 * k_stride; + w3 += 8 * k_stride; + w4 += 8 * k_stride; + w5 += 8 * k_stride; + w6 += 8 * k_stride; + w7 += 8 * k_stride; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m128i vzero = _mm_setzero_si128(); + __m128i v0x01234567 = _mm_loadu_si64(w0); + __m128i v1x01234567 = vzero; + if (1 < k) { + v1x01234567 = _mm_loadu_si64(w1); + } + __m128i v2x01234567 = vzero; + if (2 < k) { + v2x01234567 = _mm_loadu_si64(w2); + } + __m128i v3x01234567 = vzero; + if (3 < k) { + v3x01234567 = _mm_loadu_si64(w3); + } + __m128i v4x01234567 = vzero; + if (4 < k) { + v4x01234567 = _mm_loadu_si64(w4); + } + __m128i v5x01234567 = vzero; + if (5 < k) { + v5x01234567 = _mm_loadu_si64(w5); + } + __m128i v6x01234567 = vzero; + if (6 < k) { + v6x01234567 = _mm_loadu_si64(w6); + } + __m128i v7x01234567 = vzero; + if (7 < k) { + v7x01234567 = _mm_loadu_si64(w7); + } + + __m128i 
v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += k * k_stride; + w1 += k * k_stride; + w2 += k * k_stride; + w3 += k * k_stride; + w4 += k * k_stride; + w5 += k * k_stride; + w6 += k * k_stride; + w7 += k * k_stride; + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w0 - kc * k_stride + 8; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((int32_t*) out) = 0; + out += sizeof(int32_t); + } while (--nb != 0); + } + out += (8 - n) * sizeof(int32_t); + + const int8_t* w1 = w0 + k_stride; + const int8_t* w2 = w1 + k_stride; + const int8_t* w3 = w2 + k_stride; + const int8_t* w4 = w3 + k_stride; + const int8_t* w5 = w4 + k_stride; + const int8_t* w6 = w5 + k_stride; + const int8_t* w7 = w6 + k_stride; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m128i v0x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w0, n)); + __m128i v1x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w1, n)); + __m128i v2x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w2, n)); + __m128i v3x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w3, n)); + __m128i v4x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w4, n)); + __m128i v5x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w5, n)); + __m128i v6x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w6, n)); + __m128i v7x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w7, n)); + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = 
_mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8 * k_stride; + w1 += 8 * k_stride; + w2 += 8 * k_stride; + w3 += 8 * k_stride; + w4 += 8 * k_stride; + w5 += 8 * k_stride; + w6 += 8 * k_stride; + w7 += 8 * k_stride; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m128i vzero = _mm_setzero_si128(); + __m128i v0x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w0, n)); + __m128i v1x01234567 = vzero; + if (1 < k) { + v1x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w1, n)); + } + __m128i v2x01234567 = vzero; + if (2 < k) { + v2x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w2, n)); + } + __m128i v3x01234567 = vzero; + if (3 < k) { + v3x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w3, n)); + } + __m128i v4x01234567 = vzero; + if (4 < k) { + v4x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w4, n)); + } + __m128i v5x01234567 = vzero; + if (5 < k) { + v5x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w5, n)); + } + __m128i v6x01234567 = vzero; + if (6 < k) { + v6x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w6, n)); + } + __m128i v7x01234567 = vzero; + if (7 < k) { + v7x01234567 = _mm_set1_epi64x((int64_t) safe_load_u64(w7, n)); + } + + __m128i v01x01234567 = _mm_unpacklo_epi8(v0x01234567, v1x01234567); + __m128i v23x01234567 = _mm_unpacklo_epi8(v2x01234567, v3x01234567); + __m128i v45x01234567 = _mm_unpacklo_epi8(v4x01234567, v5x01234567); + __m128i v67x01234567 = _mm_unpacklo_epi8(v6x01234567, v7x01234567); + + __m128i v0123x0123 = _mm_unpacklo_epi16(v01x01234567, v23x01234567); + __m128i v0123x4567 = _mm_unpackhi_epi16(v01x01234567, v23x01234567); + __m128i v4567x0123 = _mm_unpacklo_epi16(v45x01234567, v67x01234567); + __m128i v4567x4567 = _mm_unpackhi_epi16(v45x01234567, v67x01234567); + + __m128i v01234567x01 = _mm_unpacklo_epi32(v0123x0123, v4567x0123); + __m128i v01234567x23 = _mm_unpackhi_epi32(v0123x0123, v4567x0123); + __m128i v01234567x45 = _mm_unpacklo_epi32(v0123x4567, v4567x4567); + __m128i v01234567x67 = _mm_unpackhi_epi32(v0123x4567, v4567x4567); + + __m256i v0 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x01), v01234567x23, 1); + __m256i v4 = _mm256_inserti128_si256(_mm256_castsi128_si256(v01234567x45), v01234567x67, 1); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += k * k_stride; + w1 += k * k_stride; + w2 += k * k_stride; + w3 += k * k_stride; + w4 += k * k_stride; + w5 += k * k_stride; + w6 += k * k_stride; + w7 += k * k_stride; + out += 64; + 
} + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w0 - kc * k_stride + 8; + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c index 37348495a8dc..785aafc77545 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx2-madd.c @@ -19,13 +19,23 @@ // AVXVNNI replacement that uses vpmaddubsw. // u7 is vone. s8 is int8 weights. static XNN_INTRINSIC -__m256i _mm256_dpbusd_epi32_madd(__m256i i32, const __m256i u7, const __m256i s8) { +__m256i mm256_dpbusd_epi32_madd(__m256i i32, const __m256i u7, const __m256i s8) { const __m256i vone = _mm256_set1_epi16(1); const __m256i i16 = _mm256_maddubs_epi16(u7, s8); // u7 * s8 = s16 const __m256i v = _mm256_madd_epi16(i16, vone); // convert 16 bits to 32 bits return _mm256_add_epi32(i32, v); } +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( size_t g, @@ -39,7 +49,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -49,19 +59,27 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -72,13 +90,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); @@ -105,23 +116,23 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0_0); - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0_1); - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0_2); - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0_3); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4_0); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4_1); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4_2); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4_3); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_0); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_1); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_2); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_3); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_0); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_1); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_2); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_3); _mm256_storeu_si256((__m256i *)&out[0], v0_0); _mm256_storeu_si256((__m256i *)&out[32], v4_0); @@ -154,8 +165,8 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( v4 = _mm256_blend_epi32(v4, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4); _mm256_storeu_si256((__m256i *)&out[0], v0); _mm256_storeu_si256((__m256i *)&out[32], v4); @@ -175,18 +186,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -197,8 +204,8 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( w6 += k; w7 += k; - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4); _mm256_storeu_si256((__m256i *)&out[0], v0); _mm256_storeu_si256((__m256i *)&out[32], v4); @@ -217,25 +224,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -265,11 +257,83 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( w7 = w6; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + __m256i vacc0 = 
_mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_0); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_1); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_2); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0_3); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_0); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_1); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_2); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -280,8 +344,8 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, 
v4); _mm256_storeu_si256((__m256i *)&out[0], v0); _mm256_storeu_si256((__m256i *)&out[32], v4); @@ -301,18 +365,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -323,8 +383,8 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( w6 += k; w7 += k; - vacc0 = _mm256_dpbusd_epi32_madd(vacc0, vone, v0); - vacc4 = _mm256_dpbusd_epi32_madd(vacc4, vone, v4); + vacc0 = mm256_dpbusd_epi32_madd(vacc0, vone, v0); + vacc4 = mm256_dpbusd_epi32_madd(vacc4, vone, v4); _mm256_storeu_si256((__m256i *)&out[0], v0); _mm256_storeu_si256((__m256i *)&out[32], v4); @@ -338,9 +398,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx2_madd( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c index da219c616b37..3fb2c526c6a5 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,19 +51,27 @@ void 
xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -64,13 +82,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -161,14 +172,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 448); xnn_prefetch_to_l1((const int8_t*) w7 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -239,18 +250,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -281,25 +288,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -328,28 +320,148 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( if XNN_UNPREDICTABLE(n < 8) { w7 = w6; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + 
xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + 
__m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -389,18 +501,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -426,9 +534,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = 
_mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c index 827c33f438db..ad92b04f13bc 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,19 +50,27 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -63,13 +81,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); @@ -96,14 +107,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = 
_mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -166,18 +177,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -208,25 +215,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -256,11 +248,83 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( w7 = w6; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 
size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -292,18 +356,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, 
vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -329,9 +389,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c index 5010f4bbcddb..a6df972026f1 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,19 +51,27 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -64,13 +82,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -161,14 +172,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 448); xnn_prefetch_to_l1((const int8_t*) w7 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -239,18 +250,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -281,25 +288,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -328,28 +320,148 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( if XNN_UNPREDICTABLE(n < 8) { w7 = w6; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 
128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = 
_mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -389,18 +501,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -426,9 +534,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c index f85b456c7017..e98e70b7ef82 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c +++ 
b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,19 +50,27 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -63,13 +81,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); @@ -96,14 +107,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = 
_mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -166,18 +177,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -208,25 +215,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -256,11 +248,83 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( w7 = w6; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = 
_mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -292,18 +356,14 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = 
_mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C);
+ v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30);
+ v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0);
+ __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k));
+ v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C);
+ v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30);
+ v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0);
 w0 += k;
 w1 += k;
@@ -329,9 +389,10 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni(
 __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0));
 vpack0 = _mm256_sub_epi32(vpack0, vksum0);
 _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0);
+ out = (int8_t*) ((uintptr_t) out + extra_bytes);
 }
- weights += nc * kc;
+ weights = (const int8_t*)((intptr_t) weights + nc * kc);
 } while (--g != 0);
}
diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
index 881ea5af2a2e..e05509486f26 100644
--- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
+++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
@@ -14,6 +14,16 @@
 #include "xnnpack/packw.h"
+XNN_INLINE static v128_t safe_v128_load64_splat(const void* address, size_t n) {
+ assert(n >= 1 && n <= sizeof(uint64_t));
+ const uint8_t* bytes = (const uint8_t*) address;
+ uint64_t value = (uint64_t) bytes[0];
+ for (size_t i = 1; i < n; ++i) {
+ value |= (uint64_t) bytes[i] << (i * 8);
+ }
+
+ return wasm_u64x2_splat(value);
+}
 void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 size_t g,
@@ -27,7 +37,7 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 const void* scale,
 int8_t* packed_weights,
 size_t extra_bytes,
- const void* params) XNN_OOB_READS
+ const void* params)
 {
 assert(g != 0);
 assert(nc != 0);
@@ -164,28 +174,26 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 out += 64;
 }
+ // Load earlier to avoid unexpected rescheduling.
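+ // These hoisted values are consumed below: vksum0123/vksum4567 are subtracted from them and the results stored back to packed_b.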
+ v128_t vpack0123 = wasm_v128_load(packed_b);
+ v128_t vpack4567 = wasm_v128_load(packed_b + 4);
+
 // KC remainder 1..KR-1
 if (k != 0) {
 assert(k >= 1 && k <= 7);
- const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);
-
- const v128_t v0 = wasm_v128_load64_splat(w0);
- const v128_t v1 = wasm_v128_load64_splat(w1);
- v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
- v01 = wasm_v128_and(v01, vmask);
- const v128_t v2 = wasm_v128_load64_splat(w2);
- const v128_t v3 = wasm_v128_load64_splat(w3);
- v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
- v23 = wasm_v128_and(v23, vmask);
- const v128_t v4 = wasm_v128_load64_splat(w4);
- const v128_t v5 = wasm_v128_load64_splat(w5);
- v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
- v45 = wasm_v128_and(v45, vmask);
- const v128_t v6 = wasm_v128_load64_splat(w6);
- const v128_t v7 = wasm_v128_load64_splat(w7);
- v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
- v67 = wasm_v128_and(v67, vmask);
+ const v128_t v0 = safe_v128_load64_splat(w0, k);
+ const v128_t v1 = safe_v128_load64_splat(w1, k);
+ const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
+ const v128_t v2 = safe_v128_load64_splat(w2, k);
+ const v128_t v3 = safe_v128_load64_splat(w3, k);
+ const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
+ const v128_t v4 = safe_v128_load64_splat(w4, k);
+ const v128_t v5 = safe_v128_load64_splat(w5, k);
+ const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
+ const v128_t v6 = safe_v128_load64_splat(w6, k);
+ const v128_t v7 = safe_v128_load64_splat(w7, k);
+ const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
 vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
 vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
@@ -214,9 +222,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
 vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);
- v128_t vpack0123 = wasm_v128_load(packed_b);
- v128_t vpack4567 = wasm_v128_load(packed_b + 4);
-
 wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
 wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));
@@ -315,28 +320,26 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 out += 64;
 }
+ // Load earlier to avoid unexpected rescheduling.
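+ // As in the main NC loop above, the hoisted packed_b values feed the vksum subtraction that follows the KC remainder.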
+ v128_t vpack0123 = wasm_v128_load(packed_b); + v128_t vpack4567 = wasm_v128_load(packed_b + 4); + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8); - - const v128_t v0 = wasm_v128_load64_splat(w0); - const v128_t v1 = wasm_v128_load64_splat(w1); - v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3); - v01 = wasm_v128_and(v01, vmask); - const v128_t v2 = wasm_v128_load64_splat(w2); - const v128_t v3 = wasm_v128_load64_splat(w3); - v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3); - v23 = wasm_v128_and(v23, vmask); - const v128_t v4 = wasm_v128_load64_splat(w4); - const v128_t v5 = wasm_v128_load64_splat(w5); - v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3); - v45 = wasm_v128_and(v45, vmask); - const v128_t v6 = wasm_v128_load64_splat(w6); - const v128_t v7 = wasm_v128_load64_splat(w7); - v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3); - v67 = wasm_v128_and(v67, vmask); + const v128_t v0 = safe_v128_load64_splat(w0, k); + const v128_t v1 = safe_v128_load64_splat(w1, k); + const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3); + const v128_t v2 = safe_v128_load64_splat(w2, k); + const v128_t v3 = safe_v128_load64_splat(w3, k); + const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3); + const v128_t v4 = safe_v128_load64_splat(w4, k); + const v128_t v5 = safe_v128_load64_splat(w5, k); + const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3); + const v128_t v6 = safe_v128_load64_splat(w6, k); + const v128_t v7 = safe_v128_load64_splat(w7, k); + const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3); vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01); vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23); @@ -357,9 +360,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint); vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint); - v128_t vpack0123 = wasm_v128_load(packed_b); - v128_t vpack4567 = wasm_v128_load(packed_b + 4); - wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123)); wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567)); diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index 80b956906992..1d6cf6e10930 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -40,9 +40,15 @@ XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxv XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm, 16, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni, 16, 8, 1, 8, 1, 128) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm, 16, 8, 1, 8, 1, 128) + +XNN_QS8_GIO_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_gio_ukernel_x8c8__avxvnni, 8, 8, 1, 8, 1, 0) +XNN_QS8_GIO_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_gio_ukernel_x8c8__avxvnni_prfm, 8, 8, 1, 8, 1, 0) #endif #if XNN_ENABLE_AVX256VNNI && (XNN_ARCH_X86_64 || XNN_ARCH_X86) +XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni, 64, 4, 1, 4, 1, 0) +XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x64c4__avx256vnni_prfm, 64, 4, 1, 4, 1, 0) + XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni, 8, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm, 8, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, 
xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni, 8, 8, 1, 8, 1, 128) diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni-prfm.c new file mode 100644 index 000000000000..4195347bef8d --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni-prfm.c @@ -0,0 +1,958 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes. 
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) 
w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); 
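+ // The vacc accumulators gather sums of the unpacked nibble values (via xnn_packed2planar); they are later scaled by the zero point and subtracted from packed_b.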
+ + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) 
w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + 
w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = 
_mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; 
+ } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = 
_mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 
3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = 
_mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while (--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni.c new file mode 100644 index 000000000000..ca225ad6186b --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avx256vnni.c @@ -0,0 +1,669 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avx256vnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes.
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i 
v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = 
xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = 
_mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if 
XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i 
v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], 
v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while (--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni-prfm.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni-prfm.c new file mode 100644 index 000000000000..fea0b3a21255 --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni-prfm.c @@ -0,0 +1,958 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes.
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) 
w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); 
+ + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) 
w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + 
w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = 
_mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; 
+ } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = 
_mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 
3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = 
_mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while (--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni.c new file mode 100644 index 000000000000..abc91700edf0 --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x16c8-gemm-goi-avxvnni.c @@ -0,0 +1,669 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avxvnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes. 
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i 
v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = 
xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = 
_mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if 
XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + const __m256i v8_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w8), vkernel_zero_point); // uint4 -> int4 + const __m256i v9_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w9), vkernel_zero_point); // uint4 -> int4 + const __m256i v10_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w10), vkernel_zero_point); // uint4 -> int4 + const __m256i v11_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w11), vkernel_zero_point); // uint4 -> int4 + const __m256i v12_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w12), vkernel_zero_point); // uint4 -> int4 + const __m256i v13_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w13), vkernel_zero_point); // uint4 -> int4 + const __m256i v14_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w14), vkernel_zero_point); // uint4 -> int4 + const __m256i v15_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w15), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i 
v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + v8_0 = xnn_packed2planar(&vacc8, v8_0, vmask, vone); + v8_1 = xnn_packed2planar(&vacc8, v8_1, vmask, vone); + v8_2 = xnn_packed2planar(&vacc8, v8_2, vmask, vone); + v8_3 = xnn_packed2planar(&vacc8, v8_3, vmask, vone); + v12_0 = xnn_packed2planar(&vacc12, v12_0, vmask, vone); + v12_1 = xnn_packed2planar(&vacc12, v12_1, vmask, vone); + v12_2 = xnn_packed2planar(&vacc12, v12_2, vmask, vone); + v12_3 = xnn_packed2planar(&vacc12, v12_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], 
v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + w8 += k; + w9 += k; + w10 += k; + w11 += k; + w12 += k; + w13 += k; + w14 += k; + w15 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + v8 = _mm256_xor_si256(v8, vkernel_zero_point); // uint4 -> int4 + v8 = xnn_packed2planar(&vacc8, v8, vmask, vone); + v12 = _mm256_xor_si256(v12, vkernel_zero_point); // uint4 -> int4 + v12 = xnn_packed2planar(&vacc12, v12, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while (--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni-prfm.c new file mode 100644 index 000000000000..7f55727caa96 --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni-prfm.c @@ -0,0 +1,570 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include &lt;assert.h&gt; +#include &lt;stddef.h&gt; +#include &lt;stdint.h&gt; + +#include &lt;immintrin.h&gt; + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes.
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i 
*)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = 
w7; + } + + // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + 
_mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while 
(--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c new file mode 100644 index 000000000000..6d30769b7374 --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c @@ -0,0 +1,425 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include &lt;assert.h&gt; +#include &lt;stddef.h&gt; +#include &lt;stdint.h&gt; + +#include &lt;immintrin.h&gt; + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes.
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = 
xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, 
vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, 
v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = 
_mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while (--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni-prfm.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni-prfm.c new file mode 100644 index 000000000000..576382508622 --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni-prfm.c @@ -0,0 +1,570 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include &lt;assert.h&gt; +#include &lt;stddef.h&gt; +#include &lt;stdint.h&gt; + +#include &lt;immintrin.h&gt; + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes.
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i 
*)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = 
w7; + } + + // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + 
_mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while 
(--g != 0); +} diff --git a/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c new file mode 100644 index 000000000000..9b97fca3cf00 --- /dev/null +++ b/src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c @@ -0,0 +1,425 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + +// Convert a vector from packed nibbles to planar, and accumulate sum +static XNN_INTRINSIC +__m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v01); + *vacc = _mm256_dpbusd_avx_epi32(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); +} + +void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* weights, + const int32_t* bias, + const float* scale, + void* packed_weights, + size_t extra_bytes, + const struct xnn_qs8_qc4w_packing_params* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + assert(params != NULL); + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes. 
Measure in bytes + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + 0); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, v0_0, vmask, vone); + v0_1 = 
xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, 
vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + // Clamp weight pointers for NC remainder + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4 + const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4 + const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4 + const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4 + const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4 + const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4 + const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4 + const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4 + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + v0_0 = xnn_packed2planar(&vacc0, 
v0_0, vmask, vone); + v0_1 = xnn_packed2planar(&vacc0, v0_1, vmask, vone); + v0_2 = xnn_packed2planar(&vacc0, v0_2, vmask, vone); + v0_3 = xnn_packed2planar(&vacc0, v0_3, vmask, vone); + v4_0 = xnn_packed2planar(&vacc4, v4_0, vmask, vone); + v4_1 = xnn_packed2planar(&vacc4, v4_1, vmask, vone); + v4_2 = xnn_packed2planar(&vacc4, v4_2, vmask, vone); + v4_3 = xnn_packed2planar(&vacc4, v4_3, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + v0 = _mm256_xor_si256(v0, vkernel_zero_point); // uint4 -> int4 + v0 = xnn_packed2planar(&vacc0, v0, vmask, vone); + v4 = _mm256_xor_si256(v4, vkernel_zero_point); // uint4 -> int4 + v4 = xnn_packed2planar(&vacc4, v4, vmask, vone); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = 
_mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights = (const uint8_t*)((intptr_t) weights + nc * kc); + } while (--g != 0); +} diff --git a/src/qs8-qc4w-packw/qs8-qc4w-packw.h b/src/qs8-qc4w-packw/qs8-qc4w-packw.h index 83cb0daba083..fc74085a41e6 100644 --- a/src/qs8-qc4w-packw/qs8-qc4w-packw.h +++ b/src/qs8-qc4w-packw/qs8-qc4w-packw.h @@ -4,6 +4,22 @@ // LICENSE file in the root directory of this source tree. // arch_flags, ukernel, nr, kr, sr, kblock, nr_scale -XNN_UKERNEL(0, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__scalar, 8, 8, 1, 8, 1) -XNN_UKERNEL(0, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 8, 1) -XNN_UKERNEL(0, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x32c8__scalar, 32, 8, 1, 8, 1) +XNN_UKERNEL(0, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__scalar, 8, 8, 1, 16, 1) +XNN_UKERNEL(0, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 16, 1) +XNN_UKERNEL(0, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x32c8__scalar, 32, 8, 1, 16, 1) + +#if XNN_ENABLE_AVXVNNI && (XNN_ARCH_X86_64 || XNN_ARCH_X86) +XNN_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni, 8, 8, 1, 16, 1) +XNN_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm, 8, 8, 1, 16, 1) + +XNN_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avxvnni, 16, 8, 1, 32, 1) +XNN_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm, 16, 8, 1, 32, 1) +#endif + +#if XNN_ENABLE_AVX256VNNI && (XNN_ARCH_X86_64 || XNN_ARCH_X86) +XNN_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni, 8, 8, 1, 16, 1) +XNN_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm, 8, 8, 1, 16, 1) + +XNN_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avx256vnni, 16, 8, 1, 32, 1) +XNN_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_qc4w_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm, 16, 8, 1, 32, 1) +#endif diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c index a21e06ca44d1..18839780fca5 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,7 +48,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -62,19 +67,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -214,25 +219,36 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. // Add tile to bias - _tile_stored(0, res0, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + __m512i 
vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx.c index 720d9080cc68..d45bc2e48b91 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x16c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,7 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -61,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -181,25 +186,36 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. // Add tile to bias - _tile_stored(0, res0, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + __m512i 
vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c index 73717a81ccae..460b5f874137 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,8 +48,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -63,19 +67,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -237,42 +241,53 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + __m512i vacc0x0123456789ABCDEF = 
_mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx.c index e0164310798b..2b81c94edff3 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,8 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -62,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -188,42 +192,53 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + __m512i vacc0x0123456789ABCDEF = 
_mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c index b67fe3dc2d02..9afa54f37b36 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -43,10 +48,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +67,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -283,76 +285,87 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - __m512i vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - __m512i 
vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - __m512i vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + 
__m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + __m512i vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + __m512i vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, 
_mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx.c index eacb030cba32..74a9f92d36c6 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x64c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,10 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -202,76 +204,87 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
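// The long unrolled "Add tile to bias" blocks in this patch all follow one
// indexing rule: _tile_stored(t, &res[t][0], 64) writes row r of tile t at
// &res[t][16 * r] (the 64-byte row stride is 16 int32 lanes), and each stored
// row is then added to the per-column weight-sum/bias vector. A compact loop
// form of that rule, with hypothetical names (vksum[], acc[][]) standing in
// for the per-tile vksum.../vacc... values of the generated code; only the
// first mr rows of each tile are valid, which is why the loop stops at mr:
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

static inline void add_tiles_to_bias(const int32_t (*res)[16 * 16],  // 64-byte-aligned tile scratch
                                     const __m512i* vksum,           // one bias vector per tile
                                     __m512i acc[][4],               // acc[row][tile]
                                     size_t mr, size_t ntiles) {
  for (size_t row = 0; row < mr; row++) {
    for (size_t tile = 0; tile < ntiles; tile++) {
      // Row `row` of tile `tile` starts 16 int32 values (64 bytes) into res[tile].
      acc[row][tile] = _mm512_add_epi32(
          vksum[tile], _mm512_load_epi32(&res[tile][0] + 16 * row));
    }
  }
}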
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - __m512i vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - __m512i 
vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - __m512i vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + 
__m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + __m512i vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + __m512i vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, 
_mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512amx.c index bab5660bd2c0..cb26da1168fb 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,7 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; + __attribute__((aligned(64))) int32_t res[1][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -61,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -121,10 +126,21 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
+ // Add tile to bias + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-avx512amx.c index 2f8dacb75871..5beb686743fc 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,8 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; + __attribute__((aligned(64))) int32_t res[2][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -62,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -128,12 +132,23 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
+ // Add tile to bias + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x64c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x64c4-minmax-fp32-avx512amx.c index 09a68cfd680b..17bbb18321c2 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x64c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x64c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,10 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -142,16 +144,27 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
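// Right after the bias add, each vaccNx... vector is converted to float with
// _mm512_cvtepi32_ps, which begins the fp32 requantization. A rough sketch of
// that path for one vector, under the assumption that it multiplies by a
// per-channel scale, clamps, and converts back; the parameter names are
// placeholders, and the generated kernels may order the clamp and zero-point
// add differently (the lower clamp typically comes from the saturating pack
// to int8, which is not shown here):
#include <immintrin.h>

static inline __m512i requantize_fp32_sketch(__m512i vacc, __m512 vscale,
                                             __m512 voutput_max_less_zero_point,
                                             __m512i voutput_zero_point) {
  __m512 vscaled = _mm512_cvtepi32_ps(vacc);                      // int32 accumulators -> float
  vscaled = _mm512_mul_ps(vscaled, vscale);                       // per-channel scale
  vscaled = _mm512_min_ps(vscaled, voutput_max_less_zero_point);  // upper clamp in float
  vacc = _mm512_cvtps_epi32(vscaled);                             // round back to int32
  return _mm512_add_epi32(vacc, voutput_zero_point);              // shift by output zero point
}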
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512amx.c index 9fc1e2355ee1..40b78408bd6d 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x16c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,7 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; + __attribute__((aligned(64))) int32_t res[1][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -61,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -145,16 +150,27 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. // Add tile to bias - _tile_stored(0, res0, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git 
a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x32c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x32c4-minmax-fp32-avx512amx.c index b1c4a1f41546..999144bd1401 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x32c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x32c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,8 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; + __attribute__((aligned(64))) int32_t res[2][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -62,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -152,24 +156,35 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x64c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x64c4-minmax-fp32-avx512amx.c index 4ecd9908db57..ca40f2796b9a 100644 --- a/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x64c4-minmax-fp32-avx512amx.c +++ 
b/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-7x64c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -42,10 +47,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx( // TODO: amxintrin.h only provide intrinsics for __x86_64__ // Update if amxintrin changes #if defined(__x86_64__) - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +66,19 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -166,40 +168,51 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx( k -= kremainder * sizeof(int8_t); } + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + 
__m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx-prfm.c index bcac8a3b865d..b8b397894f0e 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. 
#include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -45,7 +50,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +70,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -428,25 +433,36 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
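/*
 * Editor's sketch (illustrative, not taken from the kernel): what the unrolled
 * "add tile to bias" sequences below compute. Row m of spilled tile t starts at
 * res[t][m * 16], i.e. 16 int32 lanes, so each per-row accumulator is the
 * per-column bias sum for that tile plus that row. The helper name and
 * signature are invented; the load assumes the 64-byte alignment the kernels
 * give the spill buffer.
 */
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

static inline __m512i bias_add_row(const int32_t res[][16 * 16],
                                   __m512i vksum, size_t tile, size_t row) {
  return _mm512_add_epi32(vksum, _mm512_load_epi32(&res[tile][row * 16]));
}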
// Add tile to bias - _tile_stored(0, res0, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx.c index 738ffc27a3d8..92b7154ac154 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x16c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,7 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; + __attribute__((aligned(64))) int32_t res[1][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -363,25 +368,36 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, 
_mm512_load_epi32(&res[0][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx-prfm.c index 3535a5e1606a..8d3e2a08fd30 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -45,8 +50,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -66,19 +70,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -469,42 +473,53 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + __m512i vacc0x0123456789ABCDEF = 
_mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx.c index d2bd0bdf6360..df05d33c6e26 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x32c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,8 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; + __attribute__((aligned(64))) int32_t res[2][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -372,42 +376,53 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
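/*
 * Editor's sketch condensing the tile_data setup shown earlier in each of these
 * kernels: tmm0..tmm3 hold results, tmm4/tmm6 the input (full / remainder),
 * tmm5/tmm7 the weights (full / remainder). The struct below assumes Intel's
 * documented ldtilecfg layout (palette byte, start_row byte, 14 reserved bytes,
 * 16 x uint16 colsb, 16 x uint8 rows); the struct, function, and variable names
 * are invented, and colsb[7] = 64 is an assumption matching the other weight tile.
 */
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

struct tile_config {
  uint8_t palette_id;
  uint8_t start_row;
  uint8_t reserved[14];
  uint16_t colsb[16];
  uint8_t rows[16];
};

static void configure_tiles(size_t mr, size_t kremainder) {
  __attribute__((aligned(64))) struct tile_config cfg = {0};
  cfg.palette_id = 1;
  for (int t = 0; t < 4; t++) {
    cfg.rows[t] = (uint8_t) mr;  cfg.colsb[t] = 64;                      // result tiles res[0..3]
  }
  cfg.rows[4] = (uint8_t) mr;                cfg.colsb[4] = 64;                      // input
  cfg.rows[5] = 16;                          cfg.colsb[5] = 64;                      // weights
  cfg.rows[6] = (uint8_t) mr;                cfg.colsb[6] = (uint16_t) kremainder;   // input remainder
  cfg.rows[7] = (uint8_t) (kremainder >> 2); cfg.colsb[7] = 64;                      // weights remainder
  _tile_loadconfig(&cfg);
}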
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); + __m512i vacc0x0123456789ABCDEF = 
_mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, 
_mm512_load_epi32(&res[1][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx-prfm.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx-prfm.c index 03b488698c03..01e241d253be 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx-prfm.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx-prfm.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -45,10 +50,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -68,19 +70,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -551,76 +553,87 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - 
__m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - __m512i vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - __m512i vacc9x0123456789ABCDEF = 
_mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - __m512i vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
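/*
 * Editor's note (illustrative, not part of the patch): folding the four
 * separate res0..res3 spill buffers into one 2D array is what lets the hunks
 * above unpoison the whole spill area with a single call, since sizeof(res)
 * now covers every tile. The variable name below is invented.
 */
#include <stdint.h>
__attribute__((aligned(64))) static int32_t res_example[4][16 * 16];
_Static_assert(sizeof(res_example) == 4 * 16 * 16 * sizeof(int32_t),
               "one sizeof spans all four spilled tiles");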
+ // Add tile to bias + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + __m512i vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + __m512i vacc15xWXYZabcdefghijkl = 
_mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx.c index eb5c5c400290..34909b062fc3 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-16x64c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,10 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[16 * 16]; - __attribute__((aligned(64))) int32_t res0[16 * 16]; - __attribute__((aligned(64))) int32_t res1[16 * 16]; - __attribute__((aligned(64))) int32_t res2[16 * 16]; - __attribute__((aligned(64))) int32_t res3[16 * 16]; + __attribute__((aligned(64))) int32_t res[4][16 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -67,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -390,76 +392,87 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx( p -= 16 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 
0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); - __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 112)); - __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 112)); - __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 112)); - __m512i vacc7xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 112)); - __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 128)); - __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 128)); - __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 128)); - __m512i vacc8xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 128)); - __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 144)); - __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 144)); - __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 144)); - __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 144)); - __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 160)); - __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 160)); - __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 160)); - __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 160)); - __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 176)); - __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 176)); - __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 176)); - __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 176)); - __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 192)); - __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 192)); - __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 192)); - __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 192)); - __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 208)); - __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 208)); - __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 208)); - __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 208)); - __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 224)); - __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 224)); - __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 224)); - __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 224)); - __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 240)); - __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 240)); - __m512i vacc15xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 240)); - __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 240)); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
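+  // res[j] holds the 16x16 int32 result of tile j (output columns 16*j..16*j+15); row r is read back from &res[j][0] + 16*r below.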
+ // Add tile to bias + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); + __m512i vacc7x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 112)); + __m512i vacc7xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 112)); + __m512i vacc7xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 112)); + __m512i vacc7xmnopqrstuvwxyz01 = 
_mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 112)); + __m512i vacc8x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 128)); + __m512i vacc8xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 128)); + __m512i vacc8xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 128)); + __m512i vacc8xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 128)); + __m512i vacc9x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 144)); + __m512i vacc9xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 144)); + __m512i vacc9xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 144)); + __m512i vacc9xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 144)); + __m512i vacc10x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 160)); + __m512i vacc10xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 160)); + __m512i vacc10xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 160)); + __m512i vacc10xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 160)); + __m512i vacc11x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 176)); + __m512i vacc11xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 176)); + __m512i vacc11xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 176)); + __m512i vacc11xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 176)); + __m512i vacc12x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 192)); + __m512i vacc12xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 192)); + __m512i vacc12xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 192)); + __m512i vacc12xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 192)); + __m512i vacc13x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 208)); + __m512i vacc13xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 208)); + __m512i vacc13xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 208)); + __m512i vacc13xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 208)); + __m512i vacc14x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 224)); + __m512i vacc14xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 224)); + __m512i vacc14xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 224)); + __m512i vacc14xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 224)); + __m512i vacc15x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 240)); + __m512i vacc15xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 240)); + __m512i vacc15xWXYZabcdefghijkl = 
_mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 240)); + __m512i vacc15xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 240)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-fp32-avx512amx.c index bccfc3134671..2b4036bc76b1 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x16c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,7 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; + __attribute__((aligned(64))) int32_t res[1][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -64,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -139,10 +144,21 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
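+  // MSan does not track the AMX _tile_stored write into res, so the buffer is unpoisoned above before the plain vector load below.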
+ // Add tile to bias + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x32c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x32c4-minmax-fp32-avx512amx.c index 206345ff9a5a..b5365dc9d8a8 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x32c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x32c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,8 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x32c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; + __attribute__((aligned(64))) int32_t res[2][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -146,12 +150,23 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x32c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); - // Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
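+  // One output row split across two tiles: res[0] holds columns 0..15 and res[1] holds columns 16..31.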
+ // Add tile to bias + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x64c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x64c4-minmax-fp32-avx512amx.c index 0f46fa21752f..f7a19a2299d6 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x64c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x64c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,10 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[1 * 16]; - __attribute__((aligned(64))) int32_t res0[1 * 16]; - __attribute__((aligned(64))) int32_t res1[1 * 16]; - __attribute__((aligned(64))) int32_t res2[1 * 16]; - __attribute__((aligned(64))) int32_t res3[1 * 16]; + __attribute__((aligned(64))) int32_t res[4][1 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -67,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -160,16 +162,27 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx( p -= 1 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
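+  // One output row split across four tiles: res[0..3] hold columns 0..15, 16..31, 32..47 and 48..63 respectively.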
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-fp32-avx512amx.c index 4e44558cae2d..abd47ad746af 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x16c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,7 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; - __attribute__((aligned(64))) int32_t res0[7 * 16]; + __attribute__((aligned(64))) int32_t res[1][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? 
(kc & 63) : 64; @@ -64,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -237,16 +242,27 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512amx( p -= 7 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. // Add tile to bias - _tile_stored(0, res0, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF); diff --git 
a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x32c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x32c4-minmax-fp32-avx512amx.c index c82fedf4bfa5..9b4e80900f3b 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x32c4-minmax-fp32-avx512amx.c +++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x32c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,8 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x32c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; + __attribute__((aligned(64))) int32_t res[2][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -65,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x32c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -246,24 +250,35 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x32c4__avx512amx( p -= 7 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
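+  // Seven output rows across two column tiles: row r is read from &res[0][0] + 16*r (columns 0..15) and &res[1][0] + 16*r (columns 16..31).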
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x64c4-minmax-fp32-avx512amx.c b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x64c4-minmax-fp32-avx512amx.c index d433e595ce54..bc5021013ee4 100644 --- a/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x64c4-minmax-fp32-avx512amx.c 
+++ b/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-7x64c4-minmax-fp32-avx512amx.c @@ -8,6 +8,11 @@ // LICENSE file in the root directory of this source tree. #include +#if defined(__has_feature) + #if __has_feature(memory_sanitizer) + #include + #endif +#endif #include @@ -44,10 +49,7 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x64c4__avx512amx( // Update if amxintrin changes #if defined(__x86_64__) __attribute__((aligned(64))) int32_t vintile[7 * 16]; - __attribute__((aligned(64))) int32_t res0[7 * 16]; - __attribute__((aligned(64))) int32_t res1[7 * 16]; - __attribute__((aligned(64))) int32_t res2[7 * 16]; - __attribute__((aligned(64))) int32_t res3[7 * 16]; + __attribute__((aligned(64))) int32_t res[4][7 * 16]; kc = round_up_po2(kc, 4 * sizeof(int8_t)); const size_t kremainder = (kc & 63) ? (kc & 63) : 64; @@ -67,19 +69,19 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x64c4__avx512amx( // Load tile configuration __attribute__((aligned(64))) struct __tile_config tile_data = {0}; tile_data.palette_id = 1; - tile_data.rows[0] = mr; // tmm0 = res 0 - tile_data.rows[1] = mr; // tmm1 = res 1 - tile_data.rows[2] = mr; // tmm2 = res 2 - tile_data.rows[3] = mr; // tmm3 = res 3 + tile_data.rows[0] = mr; // tmm0 = res[0] + tile_data.rows[1] = mr; // tmm1 = res[1] + tile_data.rows[2] = mr; // tmm2 = res[2] + tile_data.rows[3] = mr; // tmm3 = res[3] tile_data.rows[4] = mr; // tmm4 = input tile_data.rows[5] = 16; // tmm5 = weights tile_data.rows[6] = mr; // tmm6 = input remainder tile_data.rows[7] = kremainder >> 2; // tmm7 = weights remainder - tile_data.colsb[0] = 64; // tmm0 = res 0 - tile_data.colsb[1] = 64; // tmm1 = res 1 - tile_data.colsb[2] = 64; // tmm2 = res 1 - tile_data.colsb[3] = 64; // tmm3 = res 1 + tile_data.colsb[0] = 64; // tmm0 = res[0] + tile_data.colsb[1] = 64; // tmm1 = res[1] + tile_data.colsb[2] = 64; // tmm2 = res[2] + tile_data.colsb[3] = 64; // tmm3 = res[3] tile_data.colsb[4] = 64; // tmm4 = input tile_data.colsb[5] = 64; // tmm5 = weights tile_data.colsb[6] = kremainder; // tmm6 = input remainder @@ -264,40 +266,51 @@ void xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x64c4__avx512amx( p -= 7 * sizeof(void*); } while (p != 0); + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 tile at a time (16 registers) + _tile_stored(0, &res[0][0], 64); + _tile_stored(1, &res[1][0], 64); + _tile_stored(2, &res[2][0], 64); + _tile_stored(3, &res[3][0], 64); + + // TODO: Fix msan for AMX + #if defined(__has_feature) + #if __has_feature(memory_sanitizer) + __msan_unpoison(res, sizeof(res)); + #endif + #endif + + // TODO: Instead of processing up to 4 tiles (16x64) consider + // quantizing 1 row at a time. 
// Add tile to bias - _tile_stored(0, res0, 64); - _tile_stored(1, res1, 64); - _tile_stored(2, res2, 64); - _tile_stored(3, res3, 64); - - __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 0)); - __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 0)); - __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 0)); - __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 0)); - __m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 16)); - __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 16)); - __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 16)); - __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 16)); - __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 32)); - __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 32)); - __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 32)); - __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 32)); - __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 48)); - __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 48)); - __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 48)); - __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 48)); - __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 64)); - __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 64)); - __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 64)); - __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 64)); - __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 80)); - __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 80)); - __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 80)); - __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 80)); - __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(res0 + 96)); - __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(res1 + 96)); - __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(res2 + 96)); - __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(res3 + 96)); + __m512i vacc0x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 0)); + __m512i vacc0xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 0)); + __m512i vacc0xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 0)); + __m512i vacc0xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 0)); + 
__m512i vacc1x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 16)); + __m512i vacc1xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 16)); + __m512i vacc1xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 16)); + __m512i vacc1xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 16)); + __m512i vacc2x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 32)); + __m512i vacc2xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 32)); + __m512i vacc2xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 32)); + __m512i vacc2xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 32)); + __m512i vacc3x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 48)); + __m512i vacc3xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 48)); + __m512i vacc3xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 48)); + __m512i vacc3xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 48)); + __m512i vacc4x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 64)); + __m512i vacc4xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 64)); + __m512i vacc4xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 64)); + __m512i vacc4xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 64)); + __m512i vacc5x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 80)); + __m512i vacc5xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 80)); + __m512i vacc5xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 80)); + __m512i vacc5xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 80)); + __m512i vacc6x0123456789ABCDEF = _mm512_add_epi32(vksum0123456789ABCDEF, _mm512_load_epi32(&res[0][0] + 96)); + __m512i vacc6xGHIJKLMNOPQRSTUV = _mm512_add_epi32(vksumGHIJKLMNOPQRSTUV, _mm512_load_epi32(&res[1][0] + 96)); + __m512i vacc6xWXYZabcdefghijkl = _mm512_add_epi32(vksumWXYZabcdefghijkl, _mm512_load_epi32(&res[2][0] + 96)); + __m512i vacc6xmnopqrstuvwxyz01 = _mm512_add_epi32(vksummnopqrstuvwxyz01, _mm512_load_epi32(&res[3][0] + 96)); __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF); __m512 vscaled0xGHIJKLMNOPQRSTUV = _mm512_cvtepi32_ps(vacc0xGHIJKLMNOPQRSTUV); diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c index 0d5731782fb0..7b706f6374c2 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + 
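+// The k-remainder path of this kernel uses safe_load_u64 (above) to gather only the first k bytes of each row, instead of the old masked unaligned_load_u64 over-reads, which is why XNN_OOB_READS was dropped from the declaration below.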
void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,32 +51,19 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( const int8_t* w13 = w12 + kc; const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -254,22 +265,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w14 + 448); xnn_prefetch_to_l1((const int8_t*) w15 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, 
_MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -392,28 +403,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -462,25 +467,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -541,46 +531,263 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( if XNN_UNPREDICTABLE(n < 16) { w15 = w14; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + 
xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); - xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); xnn_prefetch_to_l1((const int8_t*) w8 + 64); - xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); xnn_prefetch_to_l1((const int8_t*) w9 + 64); - xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); xnn_prefetch_to_l1((const int8_t*) w10 + 64); - xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); xnn_prefetch_to_l1((const int8_t*) w11 + 64); - xnn_prefetch_to_l1((const int8_t*) w12); + 
xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); xnn_prefetch_to_l1((const int8_t*) w12 + 64); - xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); xnn_prefetch_to_l1((const int8_t*) w13 + 64); - xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); xnn_prefetch_to_l1((const int8_t*) w14 + 64); - xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + 
const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_2); + vacc12 = 
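The permute2f128 step above finishes the transpose: after it, v0_0 holds bytes 0..7 of rows 0..3, v0_1 bytes 8..15, v0_2 bytes 16..23 and v0_3 bytes 24..31, with v4_*, v8_* and v12_* covering rows 4..7, 8..11 and 12..15. The dpbusd calls with the all-ones vector vone then accumulate running byte sums of the packed weights; per output channel this amounts to the scalar sketch below, where `weights_row[n]` is a hypothetical pointer to row n and the final reduction of the vector accumulators happens after the KC loops, outside the quoted hunk.

  int32_t ksum[16] = {0};
  for (size_t n = 0; n < 16; n++) {
    for (size_t kk = 0; kk < kc; kk++) {
      ksum[n] += (int32_t) weights_row[n][kk];  // dpbusd with an all-ones operand reduces to a byte sum
    }
  }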
_mm256_dpbusd_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -648,28 +855,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = 
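The sixteen stores above lay the transposed block out in the x16c8 order: for each 8-byte K-slice, the 8-byte groups of all sixteen rows are written back to back, so one 16x32 iteration emits 4 * 16 * 8 = 512 bytes. A scalar sketch of that layout, with hypothetical `rows` (the sixteen row pointers) and `out` standing in for the kernel's w0..w15 and output cursor, and <string.h> assumed:

  for (size_t slice = 0; slice < 4; slice++) {    // 8-byte K-slices at offsets 0, 8, 16, 24
    for (size_t row = 0; row < 16; row++) {
      memcpy(out + (slice * 16 + row) * 8, rows[row] + slice * 8, 8);
    }
  }
  out += 512;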
_mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -713,9 +914,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c index 5a5d7dfc774e..e794acd3aab0 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,32 +50,19 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
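The key functional change in these files is the KC remainder path: the old code loaded a full 8 bytes from each row with unaligned_load_u64 and masked off the invalid tail, which is why the kernels carried the XNN_OOB_READS attribute; the new safe_load_u64 helper assembles exactly k bytes little-endian and zero-fills the rest, so the attribute can be dropped and the separate vmask / _mm256_and_si256 steps disappear. A standalone sketch of the helper as added by this patch, with an illustrative main() harness that is not part of the kernel:

  #include <assert.h>
  #include <stddef.h>
  #include <stdint.h>

  static uint64_t safe_load_u64(const void* address, size_t n) {
    uint64_t value = 0;
    assert(n <= sizeof(uint64_t));
    const uint8_t* bytes = (const uint8_t*) address;
    for (size_t i = 0; i < n; ++i) {
      value |= (uint64_t) bytes[i] << (i * 8);  // bytes beyond n stay zero, matching the old mask
    }
    return value;
  }

  int main(void) {
    const uint8_t tail[3] = {0x11, 0x22, 0x33};            // e.g. a kc % 8 == 3 remainder
    assert(safe_load_u64(tail, 3) == UINT64_C(0x332211));  // never touches tail[3..7]
    return 0;
  }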
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); @@ -125,22 +136,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = 
_mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -247,28 +258,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i 
v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -317,25 +322,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -397,13 +387,134 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( w15 = w14; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, 
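In the NC remainder the bias handling moves below the pointer-clamping chain so the block mirrors the main loop's ordering, the scalar bias copy becomes an indexed loop, and the zero-bias case now uses two vector stores covering the full 16-entry slot. The clamping itself makes every row beyond the n valid ones alias the previous pointer, so all sixteen loads stay in bounds while the packed block keeps its full width; the duplicated columns are just padding. A loop-form sketch of the unrolled chain:

  const int8_t* w[16];
  w[0] = w0;
  for (size_t i = 1; i < 16; i++) {
    w[i] = (n < i + 1) ? w[i - 1] : w[i - 1] + kc;  // out-of-range rows reuse the last valid row
  }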
v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i 
*)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -455,28 +566,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = 
_mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -520,9 +625,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c index f1339f0f4698..04e36c372048 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,32 +51,19 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
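Two smaller changes recur at the end of each kernel: after the corrected bias is stored back into packed_b, `out` now also advances by `extra_bytes`, honoring the per-block gap the caller reserves (typically for per-channel data appended later), and the per-group advance of `weights` is rewritten as an explicit byte-offset cast through intptr_t. The correction itself follows the usual packing-time form sketched below; the multiply producing vksum sits above the quoted hunks, so this is the expected relationship rather than something read directly from them (izp is input_zero_point + 128, per the vzeropoint setup).

  // packed_b[n] = bias[n] - izp * sum_k w[n][k]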
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( const int8_t* w13 = w12 + kc; const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -254,22 +265,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w14 + 448); xnn_prefetch_to_l1((const int8_t*) w15 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = 
_mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -392,28 +403,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 
= _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -462,25 +467,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -541,46 +531,263 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( if XNN_UNPREDICTABLE(n < 16) { w15 = w14; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 
64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); - xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); xnn_prefetch_to_l1((const int8_t*) w8 + 64); - xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); xnn_prefetch_to_l1((const int8_t*) w9 + 64); - xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); xnn_prefetch_to_l1((const int8_t*) w10 + 64); - xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); xnn_prefetch_to_l1((const int8_t*) w11 + 64); - xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); xnn_prefetch_to_l1((const int8_t*) w12 + 64); - xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); xnn_prefetch_to_l1((const int8_t*) w13 + 64); - xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + 
xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); xnn_prefetch_to_l1((const int8_t*) w14 + 64); - xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + 
xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i 
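The only difference between this 16x32 loop and the avx256vnni one earlier in the diff is the accumulation intrinsic: the avxvnni kernels use _mm256_dpbusd_avx_epi32, the VEX-encoded AVX-VNNI form, whereas the avx256vnni kernels use _mm256_dpbusd_epi32, which requires AVX512-VNNI plus AVX512-VL; the arithmetic is identical. A minimal sketch of the two spellings (assumes <immintrin.h> and, on GCC/Clang, -mavxvnni or -mavx512vnni -mavx512vl respectively):

  __m256i acc = _mm256_setzero_si256();
  const __m256i ones = _mm256_set1_epi8(1);
  const __m256i vals = _mm256_set1_epi8(3);
  acc = _mm256_dpbusd_avx_epi32(acc, ones, vals);  // AVX-VNNI (VEX) form, used by the avxvnni kernels
  // acc = _mm256_dpbusd_epi32(acc, ones, vals);   // AVX512-VNNI/VL (EVEX) form, used by the avx256vnni kernels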
*)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -648,28 +855,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -713,9 +914,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c index 66f1bd09ac45..60a30dc1ddc6 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,32 +50,19 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 16 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -82,6 +79,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); @@ -125,22 +136,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = 
_mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -247,28 +258,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + 
__m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -317,25 +322,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -397,13 +387,134 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( w15 = w14; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); __m256i vacc8 = _mm256_setzero_si256(); __m256i vacc12 = _mm256_setzero_si256(); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, 
v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_0); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_1); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_2); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8_3); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_0); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_1); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_2); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i 
*)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -455,28 +566,22 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) 
safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -520,9 +625,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni( vpack8 = _mm256_sub_epi32(vpack8, vksum8); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c index 8750b5a7994e..5c3804395d95 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,19 +51,27 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -64,13 +82,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -161,14 +172,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 448); xnn_prefetch_to_l1((const int8_t*) w7 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -239,18 +250,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = 
_mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -281,25 +288,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -328,28 +320,148 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( if XNN_UNPREDICTABLE(n < 8) { w7 = w6; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + 
xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 
1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -389,18 +501,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -426,9 +534,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c index 7d2631167503..3ddce09366b0 100644 --- 
a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,19 +50,27 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -63,13 +81,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); @@ -96,14 +107,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = 
_mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); @@ -166,18 +177,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -208,25 +215,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -256,11 +248,83 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( w7 = w6; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = 
_mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -292,18 +356,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -329,9 +389,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c index 7b88506158b0..86afaae481af 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,19 +51,27 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -64,13 +82,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -161,14 +172,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 448); xnn_prefetch_to_l1((const int8_t*) w7 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -239,18 +250,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = 
_mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -281,25 +288,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -328,28 +320,148 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( if XNN_UNPREDICTABLE(n < 8) { w7 = w6; } - xnn_prefetch_to_l1((const int8_t*) w0); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + 
xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 
1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -389,18 +501,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -426,9 +534,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c index 9b8f2b150f88..fde9c6d74bde 100644 --- 
a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( size_t g, @@ -30,7 +40,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,19 +50,27 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const int32_t* b = (const int32_t*) bias; const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128)); do { // NC main loop multiple of 8 const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); @@ -63,13 +81,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( } out += 8 * sizeof(int32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); @@ -96,14 +107,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, 
_MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); @@ -166,18 +177,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -208,25 +215,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - int32_t* packed_b = (int32_t*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((int32_t*) out) = *b++; - out += sizeof(int32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((int32_t*) out) = 0; - out += sizeof(int32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(int32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -256,11 +248,83 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( w7 = w6; } + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((int32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + __m256i vacc0 = _mm256_setzero_si256(); __m256i vacc4 = _mm256_setzero_si256(); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i 
v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_2); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0_3); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_1); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_2); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4_3); + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -292,18 +356,14 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -329,9 +389,10 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); vpack0 = _mm256_sub_epi32(vpack0, vksum0); _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c index f479706779fc..d103fa43842d 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c @@ -14,6 +14,16 @@ #include "xnnpack/packw.h" +XNN_INLINE static v128_t safe_v128_load64_splat(const void* address, size_t n) { + assert(n >= 1 && n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + uint64_t value = (uint64_t) bytes[0]; + for (size_t i = 1; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + + return wasm_u64x2_splat(value); +} void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( size_t g, @@ -27,7 +37,7 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -164,28 +174,26 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( out += 64; } + // Load earlier to avoid unexpected rescheduling.
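
Note on the safe load helpers: safe_load_u64 above (and the WAsm counterpart safe_v128_load64_splat) replaces the earlier pattern of an 8-byte unaligned load followed by a mask, which could read up to 7 bytes past the end of the last weight row; that is also why the XNN_OOB_READS annotation is dropped from these kernel signatures. A minimal standalone sketch of the scalar helper's behavior, using a hypothetical 3-byte tail:

  #include <assert.h>
  #include <stddef.h>
  #include <stdint.h>
  #include <stdio.h>

  // Adapted from the helper added above: assemble only the first n bytes of a
  // row into a zero-extended 64-bit value, so nothing past the row is read.
  static uint64_t safe_load_u64(const void* address, size_t n) {
    uint64_t value = 0;
    assert(n <= sizeof(uint64_t));
    const uint8_t* bytes = (const uint8_t*) address;
    for (size_t i = 0; i < n; ++i) {
      value |= (uint64_t) bytes[i] << (i * 8);
    }
    return value;
  }

  int main(void) {
    // Hypothetical 3-byte tail of a weights row (k = 3).
    const int8_t tail[3] = {0x11, 0x22, 0x33};
    // Prints 0000000000332211: bytes 3..7 are zero-filled instead of masked,
    // and no byte beyond tail[] is ever touched.
    printf("%016llx\n", (unsigned long long) safe_load_u64(tail, 3));
    return 0;
  }
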
+ v128_t vpack0123 = wasm_v128_load(packed_b); + v128_t vpack4567 = wasm_v128_load(packed_b + 4); + // KC remainder 1..KR-1 if (k != 0) { assert(k >= 1 && k <= 7); - const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8); - - const v128_t v0 = wasm_v128_load64_splat(w0); - const v128_t v1 = wasm_v128_load64_splat(w1); - v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3); - v01 = wasm_v128_and(v01, vmask); - const v128_t v2 = wasm_v128_load64_splat(w2); - const v128_t v3 = wasm_v128_load64_splat(w3); - v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3); - v23 = wasm_v128_and(v23, vmask); - const v128_t v4 = wasm_v128_load64_splat(w4); - const v128_t v5 = wasm_v128_load64_splat(w5); - v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3); - v45 = wasm_v128_and(v45, vmask); - const v128_t v6 = wasm_v128_load64_splat(w6); - const v128_t v7 = wasm_v128_load64_splat(w7); - v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3); - v67 = wasm_v128_and(v67, vmask); + const v128_t v0 = safe_v128_load64_splat(w0, k); + const v128_t v1 = safe_v128_load64_splat(w1, k); + const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3); + const v128_t v2 = safe_v128_load64_splat(w2, k); + const v128_t v3 = safe_v128_load64_splat(w3, k); + const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3); + const v128_t v4 = safe_v128_load64_splat(w4, k); + const v128_t v5 = safe_v128_load64_splat(w5, k); + const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3); + const v128_t v6 = safe_v128_load64_splat(w6, k); + const v128_t v7 = safe_v128_load64_splat(w7, k); + const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3); vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01); vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23); @@ -214,9 +222,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint); vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint); - v128_t vpack0123 = wasm_v128_load(packed_b); - v128_t vpack4567 = wasm_v128_load(packed_b + 4); - wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123)); wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567)); @@ -315,28 +320,26 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( out += 64; } + // Load earlier to avoid unexpected rescheduling.
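
Note on the bias adjustment: in both packing kernels the vacc accumulators are dot products of the packed weight bytes with a vector of ones, that is, running sums of each output channel's weights. Once a column block is packed, those sums are multiplied by the shifted input zero point (input_zero_point + 128 in this qs8-to-qu8 variant) and subtracted from the packed bias, which is the vksum/vzeropoint multiply-and-subtract sequence visible above. A hypothetical scalar sketch of that adjustment:

  #include <stdint.h>
  #include <stdio.h>

  // Illustrative scalar version of the adjustment applied to each packed bias:
  //   packed_b[n] = bias[n] - zero_point * sum_k(weights[n][k])
  // All values below are made up for the example.
  int main(void) {
    const int32_t zero_point = 1 + 128;  // qs8 input zero point shifted by 128
    const int8_t weights[2][4] = {{1, 2, 3, 4}, {-1, -2, -3, -4}};
    int32_t packed_b[2] = {100, 200};    // original biases

    for (int n = 0; n < 2; ++n) {
      int32_t ksum = 0;
      for (int k = 0; k < 4; ++k) {
        ksum += weights[n][k];
      }
      packed_b[n] -= ksum * zero_point;
    }
    printf("%d %d\n", packed_b[0], packed_b[1]);  // -1190 1490
    return 0;
  }
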
+ v128_t vpack0123 = wasm_v128_load(packed_b); + v128_t vpack4567 = wasm_v128_load(packed_b + 4); + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8); - - const v128_t v0 = wasm_v128_load64_splat(w0); - const v128_t v1 = wasm_v128_load64_splat(w1); - v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3); - v01 = wasm_v128_and(v01, vmask); - const v128_t v2 = wasm_v128_load64_splat(w2); - const v128_t v3 = wasm_v128_load64_splat(w3); - v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3); - v23 = wasm_v128_and(v23, vmask); - const v128_t v4 = wasm_v128_load64_splat(w4); - const v128_t v5 = wasm_v128_load64_splat(w5); - v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3); - v45 = wasm_v128_and(v45, vmask); - const v128_t v6 = wasm_v128_load64_splat(w6); - const v128_t v7 = wasm_v128_load64_splat(w7); - v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3); - v67 = wasm_v128_and(v67, vmask); + const v128_t v0 = safe_v128_load64_splat(w0, k); + const v128_t v1 = safe_v128_load64_splat(w1, k); + const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3); + const v128_t v2 = safe_v128_load64_splat(w2, k); + const v128_t v3 = safe_v128_load64_splat(w3, k); + const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3); + const v128_t v4 = safe_v128_load64_splat(w4, k); + const v128_t v5 = safe_v128_load64_splat(w5, k); + const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3); + const v128_t v6 = safe_v128_load64_splat(w6, k); + const v128_t v7 = safe_v128_load64_splat(w7, k); + const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3); vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01); vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23); @@ -357,9 +360,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd( vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint); vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint); - v128_t vpack0123 = wasm_v128_load(packed_b); - v128_t vpack4567 = wasm_v128_load(packed_b + 4); - wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123)); wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567)); diff --git a/src/reference/binary-elementwise.cc b/src/reference/binary-elementwise.cc index f548bc142682..704e87219009 100644 --- a/src/reference/binary-elementwise.cc +++ b/src/reference/binary-elementwise.cc @@ -13,6 +13,7 @@ #include "xnnpack.h" #include "xnnpack/config-types.h" +#include "xnnpack/datatype.h" #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" @@ -66,7 +67,7 @@ void rbinaryc_ukernel_unquantized(size_t batch_size_bytes, const T* a, template const xnn_binary_elementwise_config* get_config(T) { - static_assert(!xnnpack::is_quantized::value); + static_assert(!xnnpack::is_quantized::value, ""); static xnn_binary_elementwise_config config = { (xnn_vbinary_ukernel_fn)binary_ukernel_unquantized, (xnn_vbinary_ukernel_fn)binaryc_ukernel_unquantized, diff --git a/src/reference/packing.cc b/src/reference/packing.cc index 1bd7912db061..c3789abfd45a 100644 --- a/src/reference/packing.cc +++ b/src/reference/packing.cc @@ -23,11 +23,13 @@ #include "xnnpack/unaligned.h" #if XNN_ENABLE_KLEIDIAI - #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" - #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" - #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" - #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" - #include 
"kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" #endif // XNN_ENABLE_KLEIDIAI @@ -1556,6 +1558,86 @@ void xnn_pack_kai_qs4_weights_and_biases( } } +size_t xnn_packed_stride_kai_qs8_weights_and_biases( + const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, + size_t extra_bytes) { + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + return kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, /*nr=*/1, + kr, sr); +} + +void xnn_pack_kai_qs8_weights_and_biases( + uint32_t flags, const struct xnn_gemm_config* gemm_config, + size_t input_channels, size_t output_channels, size_t groups, + size_t k_stride, const void* accumulator_init, const void* weights, + xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0, + size_t extra_data0_element_size, + xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1, + size_t extra_data1_element_size, void* packed_weights_ptr, + const void* params) { + const uint32_t nr = gemm_config->nr; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + const struct xnn_qs8_qc8w_packing_params* xnn_params = + reinterpret_cast(params); + + // Repack the packing params. + struct kai_rhs_pack_qsi8cx_params kai_params; + kai_params.lhs_zero_point = xnn_params->input_zero_point; + kai_params.scale_multiplier = xnn_params->scale_multiplier; + + const size_t weights_group_stride = + sizeof(int8_t) * input_channels * output_channels; + const size_t n_stride = round_up(output_channels, nr); + const size_t packed_weights_group_stride = + n_stride * xnn_packed_stride_kai_qs8_weights_and_biases( + gemm_config, input_channels, k_stride, + extra_data0_element_size + extra_data1_element_size); + + if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) { + for (size_t group = 0; group < groups; group++) { + kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon( + /*groups=*/1, output_channels, input_channels, nr, kr, sr, + /*rhs=*/ + reinterpret_cast((uintptr_t)weights + + group * weights_group_stride), + /*bias=*/ + extra_data0 ? reinterpret_cast(extra_data0) + + group * output_channels + : NULL, + /*scale=*/ + extra_data1 ? reinterpret_cast(extra_data1) + + group * output_channels + : NULL, + /*rhs_packed=*/ + (void*)((uintptr_t)packed_weights_ptr + + group * packed_weights_group_stride), + /*extra_bytes=*/0, &kai_params); + } + } else { + for (size_t group = 0; group < groups; group++) { + kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon( + /*groups=*/1, output_channels, input_channels, nr, kr, sr, + /*rhs=*/ + reinterpret_cast((uintptr_t)weights + + group * weights_group_stride), + /*bias=*/ + extra_data0 ? reinterpret_cast(extra_data0) + + group * output_channels + : NULL, + /*scale=*/ + extra_data1 ? 
reinterpret_cast(extra_data1) + + group * output_channels + : NULL, + /*rhs_packed=*/ + (void*)((uintptr_t)packed_weights_ptr + + group * packed_weights_group_stride), + /*extra_bytes=*/0, &kai_params); + } + } +} + size_t xnn_packed_stride_kai_f32_weights_and_biases( const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, size_t extra_bytes) { diff --git a/src/reference/unary-elementwise.cc b/src/reference/unary-elementwise.cc index 904ba0eaf130..51aa76fab7ce 100644 --- a/src/reference/unary-elementwise.cc +++ b/src/reference/unary-elementwise.cc @@ -128,7 +128,7 @@ struct ConvertOp { } }; -#ifdef XNN_HAVE_FLOAT16 +#if XNN_HAVE_FLOAT16 template <> struct ConvertOp { explicit ConvertOp(const xnn_unary_uparams*) {} diff --git a/src/runtime.c b/src/runtime.c index bc29215b0bff..68a4f800256c 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -266,7 +266,8 @@ static enum xnn_status initialize_workspace_values( // Value is purely internal to the runtime, allocate it in the workspace. value->data = (void*) ((uintptr_t) runtime->workspace->data + persistent_size + mem_alloc_tracker->usage[i].alloc_offset); - if (value->datatype == xnn_datatype_qdint8) { + if (value->datatype == xnn_datatype_qdint8 || + value->datatype == xnn_datatype_qduint8) { value->quantization.dynamic_params = (void*) ((uintptr_t) runtime->workspace->data + persistent_size + mem_alloc_tracker->usage[i].alloc_offset + xnn_tensor_get_rounded_size(value)); @@ -310,7 +311,8 @@ static enum xnn_status initialize_workspace_values( if (value->data != NULL) { // Data can be null as the runtime using this workspace might not have been set up. value->data = (void*) ((uintptr_t) value->data + workspace_data_delta); - if (value->datatype == xnn_datatype_qdint8) { + if (value->datatype == xnn_datatype_qdint8 || + value->datatype == xnn_datatype_qduint8) { value->quantization.dynamic_params = (void*) ((uintptr_t) value->quantization.dynamic_params + workspace_data_delta); } @@ -634,6 +636,8 @@ enum xnn_status xnn_create_runtime_v4( } } + runtime->threadpool = threadpool; + #ifdef XNN_SLINKY_ENABLED // If compiling with XNN_SLINKY_ENABLED defined, assume we always // want Slinky enabled, regardless of the runtime flag @@ -673,8 +677,6 @@ enum xnn_status xnn_create_runtime_v4( runtime->profiling = true; } - runtime->threadpool = threadpool; - *runtime_out = runtime; return xnn_status_success; @@ -700,7 +702,7 @@ enum xnn_status xnn_plan_memory( if (value->allocation_type == xnn_allocation_type_workspace) { // Value is purely internal to the runtime, and must be allocated in its workspace. size_t tensor_size = xnn_tensor_get_rounded_size(value); - if (value->datatype == xnn_datatype_qdint8) { + if (value->datatype == xnn_datatype_qdint8 || value->datatype == xnn_datatype_qduint8) { tensor_size += xnn_tensor_get_rounded_dynamic_quant_param_size(value); } xnn_add_value_allocation_tracker(&mem_alloc_tracker, i, tensor_size); diff --git a/src/s32-f32-vcvt/gen/s32-f32-vcvt-avx2.c b/src/s32-f32-vcvt/gen/s32-f32-vcvt-avx2.c deleted file mode 100644 index 755c7f58702d..000000000000 --- a/src/s32-f32-vcvt/gen/s32-f32-vcvt-avx2.c +++ /dev/null @@ -1,208 +0,0 @@ -// Auto-generated file. Do not edit! -// Template: src/s32-f32-vcvt/simd.c.in -// Generator: tools/xngen -// -// Copyright 2024 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
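
Note on the runtime.c hunk above: qduint8 values now receive the same workspace treatment as qdint8 ones. xnn_plan_memory reserves extra room for the dynamic quantization params, initialize_workspace_values places them directly after the rounded tensor data, and both pointers are rebased together when the workspace data pointer moves. A small standalone sketch of that layout, with hypothetical sizes standing in for xnn_tensor_get_rounded_size() and xnn_tensor_get_rounded_dynamic_quant_param_size():

  #include <stddef.h>
  #include <stdint.h>
  #include <stdio.h>

  int main(void) {
    uint8_t workspace[1024];
    const size_t persistent_size = 64;       // persistent tensors at the front
    const size_t alloc_offset = 128;         // from the memory allocation tracker
    const size_t rounded_tensor_size = 256;  // stand-in for the rounded tensor size

    // Tensor data, then its dynamic quantization params, in one tracked allocation.
    void* data = (void*) ((uintptr_t) workspace + persistent_size + alloc_offset);
    void* dynamic_params = (void*) ((uintptr_t) data + rounded_tensor_size);

    printf("data at +%zu, dynamic params at +%zu\n",
           (size_t) ((uintptr_t) data - (uintptr_t) workspace),
           (size_t) ((uintptr_t) dynamic_params - (uintptr_t) workspace));
    return 0;
  }
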
- - -#include -#include - -#include "xnnpack/simd/f32-avx2.h" -#include "xnnpack/simd/s32-avx2.h" - -#include "xnnpack/common.h" -#include "xnnpack/microparams.h" - - -void xnn_s32_f32_vcvt_ukernel__avx2_u8( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 8); - assert(xnn_simd_size_s32 == 8); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__avx2_u16( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 8); - assert(xnn_simd_size_s32 == 8); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 16 * sizeof(int32_t); batch -= 16 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - input += 2 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - output += 2 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__avx2_u24( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 8); - assert(xnn_simd_size_s32 == 8); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 24 * sizeof(int32_t); batch -= 24 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - input += 3 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = 
xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - output += 3 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__avx2_u32( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 8); - assert(xnn_simd_size_s32 == 8); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 32 * sizeof(int32_t); batch -= 32 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - const xnn_simd_s32_t vx3 = xnn_loadu_s32(input + 3 * xnn_simd_size_s32); - input += 4 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - const xnn_simd_f32_t vy3 = xnn_cvt_f32_s32(xnn_sub_s32(vx3, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - xnn_storeu_f32(output + 3 * xnn_simd_size_f32, vy3); - output += 4 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} diff --git a/src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c b/src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c deleted file mode 100644 index e76cf1b43d59..000000000000 --- a/src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c +++ /dev/null @@ -1,208 +0,0 @@ -// Auto-generated file. Do not edit! -// Template: src/s32-f32-vcvt/simd.c.in -// Generator: tools/xngen -// -// Copyright 2024 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- - -#include -#include - -#include "xnnpack/simd/f32-avx512f.h" -#include "xnnpack/simd/s32-avx512f.h" - -#include "xnnpack/common.h" -#include "xnnpack/microparams.h" - - -void xnn_s32_f32_vcvt_ukernel__avx512f_u16( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 16); - assert(xnn_simd_size_s32 == 16); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__avx512f_u32( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 16); - assert(xnn_simd_size_s32 == 16); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 32 * sizeof(int32_t); batch -= 32 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - input += 2 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - output += 2 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__avx512f_u48( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 16); - assert(xnn_simd_size_s32 == 16); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 48 * sizeof(int32_t); batch -= 48 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - input += 3 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = 
xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - output += 3 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__avx512f_u64( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 16); - assert(xnn_simd_size_s32 == 16); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 64 * sizeof(int32_t); batch -= 64 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - const xnn_simd_s32_t vx3 = xnn_loadu_s32(input + 3 * xnn_simd_size_s32); - input += 4 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - const xnn_simd_f32_t vy3 = xnn_cvt_f32_s32(xnn_sub_s32(vx3, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - xnn_storeu_f32(output + 3 * xnn_simd_size_f32, vy3); - output += 4 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} diff --git a/src/s32-f32-vcvt/gen/s32-f32-vcvt-neon.c b/src/s32-f32-vcvt/gen/s32-f32-vcvt-neon.c deleted file mode 100644 index e4aca41875d5..000000000000 --- a/src/s32-f32-vcvt/gen/s32-f32-vcvt-neon.c +++ /dev/null @@ -1,208 +0,0 @@ -// Auto-generated file. Do not edit! -// Template: src/s32-f32-vcvt/simd.c.in -// Generator: tools/xngen -// -// Copyright 2024 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- - -#include -#include - -#include "xnnpack/simd/f32-neon.h" -#include "xnnpack/simd/s32-neon.h" - -#include "xnnpack/common.h" -#include "xnnpack/microparams.h" - - -void xnn_s32_f32_vcvt_ukernel__neon_u4( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__neon_u8( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 8 * sizeof(int32_t); batch -= 8 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - input += 2 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - output += 2 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__neon_u12( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 12 * sizeof(int32_t); batch -= 12 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - input += 3 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = 
xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - output += 3 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__neon_u16( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 16 * sizeof(int32_t); batch -= 16 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - const xnn_simd_s32_t vx3 = xnn_loadu_s32(input + 3 * xnn_simd_size_s32); - input += 4 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - const xnn_simd_f32_t vy3 = xnn_cvt_f32_s32(xnn_sub_s32(vx3, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - xnn_storeu_f32(output + 3 * xnn_simd_size_f32, vy3); - output += 4 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} diff --git a/src/s32-f32-vcvt/gen/s32-f32-vcvt-scalar.c b/src/s32-f32-vcvt/gen/s32-f32-vcvt-scalar.c deleted file mode 100644 index 4244fd0add9c..000000000000 --- a/src/s32-f32-vcvt/gen/s32-f32-vcvt-scalar.c +++ /dev/null @@ -1,208 +0,0 @@ -// Auto-generated file. Do not edit! -// Template: src/s32-f32-vcvt/simd.c.in -// Generator: tools/xngen -// -// Copyright 2024 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- - -#include -#include - -#include "xnnpack/simd/f32-scalar.h" -#include "xnnpack/simd/s32-scalar.h" - -#include "xnnpack/common.h" -#include "xnnpack/microparams.h" - - -void xnn_s32_f32_vcvt_ukernel__scalar_u1( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 1); - assert(xnn_simd_size_s32 == 1); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__scalar_u2( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 1); - assert(xnn_simd_size_s32 == 1); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 2 * sizeof(int32_t); batch -= 2 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - input += 2 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - output += 2 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__scalar_u3( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 1); - assert(xnn_simd_size_s32 == 1); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 3 * sizeof(int32_t); batch -= 3 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - input += 3 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = 
xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - output += 3 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__scalar_u4( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 1); - assert(xnn_simd_size_s32 == 1); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 4 * sizeof(int32_t); batch -= 4 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - const xnn_simd_s32_t vx3 = xnn_loadu_s32(input + 3 * xnn_simd_size_s32); - input += 4 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - const xnn_simd_f32_t vy3 = xnn_cvt_f32_s32(xnn_sub_s32(vx3, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - xnn_storeu_f32(output + 3 * xnn_simd_size_f32, vy3); - output += 4 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} diff --git a/src/s32-f32-vcvt/gen/s32-f32-vcvt-wasmsimd.c b/src/s32-f32-vcvt/gen/s32-f32-vcvt-wasmsimd.c deleted file mode 100644 index 7676ccb7d531..000000000000 --- a/src/s32-f32-vcvt/gen/s32-f32-vcvt-wasmsimd.c +++ /dev/null @@ -1,208 +0,0 @@ -// Auto-generated file. Do not edit! -// Template: src/s32-f32-vcvt/simd.c.in -// Generator: tools/xngen -// -// Copyright 2024 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- - -#include -#include - -#include "xnnpack/simd/f32-wasmsimd.h" -#include "xnnpack/simd/s32-wasmsimd.h" - -#include "xnnpack/common.h" -#include "xnnpack/microparams.h" - - -void xnn_s32_f32_vcvt_ukernel__wasmsimd_u4( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__wasmsimd_u8( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 8 * sizeof(int32_t); batch -= 8 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - input += 2 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - output += 2 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__wasmsimd_u12( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 12 * sizeof(int32_t); batch -= 12 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - input += 3 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = 
xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - output += 3 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} - -void xnn_s32_f32_vcvt_ukernel__wasmsimd_u16( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) -{ - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == 4); - assert(xnn_simd_size_s32 == 4); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - for (; batch >= 16 * sizeof(int32_t); batch -= 16 * sizeof(int32_t)) { - const xnn_simd_s32_t vx0 = xnn_loadu_s32(input); - const xnn_simd_s32_t vx1 = xnn_loadu_s32(input + 1 * xnn_simd_size_s32); - const xnn_simd_s32_t vx2 = xnn_loadu_s32(input + 2 * xnn_simd_size_s32); - const xnn_simd_s32_t vx3 = xnn_loadu_s32(input + 3 * xnn_simd_size_s32); - input += 4 * xnn_simd_size_s32; - - const xnn_simd_f32_t vy0 = xnn_cvt_f32_s32(xnn_sub_s32(vx0, sub)); - const xnn_simd_f32_t vy1 = xnn_cvt_f32_s32(xnn_sub_s32(vx1, sub)); - const xnn_simd_f32_t vy2 = xnn_cvt_f32_s32(xnn_sub_s32(vx2, sub)); - const xnn_simd_f32_t vy3 = xnn_cvt_f32_s32(xnn_sub_s32(vx3, sub)); - - xnn_storeu_f32(output, vy0); - xnn_storeu_f32(output + 1 * xnn_simd_size_f32, vy1); - xnn_storeu_f32(output + 2 * xnn_simd_size_f32, vy2); - xnn_storeu_f32(output + 3 * xnn_simd_size_f32, vy3); - output += 4 * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } -} diff --git a/src/s32-f32-vcvt/s32-f32-vcvt.h b/src/s32-f32-vcvt/s32-f32-vcvt.h deleted file mode 100644 index 5d41de18d3e0..000000000000 --- a/src/s32-f32-vcvt/s32-f32-vcvt.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2023 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#ifndef XNN_CVT_UKERNEL_WITH_PARAMS -#define XNN_CVT_UKERNEL_WITH_PARAMS(arch_flags, ukernel, batch_tile, vector_tile, type_in, type_out, params_type, init_params) \ - XNN_CVT_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, type_in, type_out) -#define XNN_DEFINED_CVT_UKERNEL_WITH_PARAMS -#endif - -#ifndef XNN_CVT_UKERNEL -#define XNN_CVT_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, type_in, type_out) \ - XNN_CVT_UKERNEL_WITH_PARAMS(arch_flags, ukernel, batch_tile, vector_tile, type_in, type_out, void, /*init_params=*/nullptr) -#define XNN_DEFINED_CVT_UKERNEL -#endif - -#ifndef XNN_QUANTIZED -#define XNN_QUANTIZED(T) T -#define XNN_DEFINED_QUANTIZED -#endif - - -#if XNN_ARCH_ARM || XNN_ARCH_ARM64 -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_s32_f32_vcvt_ukernel__neon_u4, 4, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_s32_f32_vcvt_ukernel__neon_u8, 8, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_s32_f32_vcvt_ukernel__neon_u12, 12, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_arm_neon, xnn_s32_f32_vcvt_ukernel__neon_u16, 16, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_s32_f32_vcvt_ukernel__avx2_u8, 8, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_s32_f32_vcvt_ukernel__avx2_u16, 16, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_s32_f32_vcvt_ukernel__avx2_u24, 24, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx2, xnn_s32_f32_vcvt_ukernel__avx2_u32, 32, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_s32_f32_vcvt_ukernel__avx512f_u16, 16, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_s32_f32_vcvt_ukernel__avx512f_u32, 32, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_s32_f32_vcvt_ukernel__avx512f_u48, 48, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(xnn_arch_x86_avx512f, xnn_s32_f32_vcvt_ukernel__avx512f_u64, 64, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - -#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__wasmsimd_u4, 4, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(0, 
xnn_s32_f32_vcvt_ukernel__wasmsimd_u8, 8, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__wasmsimd_u12, 12, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__wasmsimd_u16, 16, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD - -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__scalar_u1, 1, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__scalar_u2, 2, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__scalar_u3, 3, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) -XNN_CVT_UKERNEL_WITH_PARAMS(0, xnn_s32_f32_vcvt_ukernel__scalar_u4, 4, false, XNN_QUANTIZED(int32_t), float, struct xnn_s32_f32_cvt_params, xnn_init_s32_f32_cvt_scalar_params) - -#ifdef XNN_DEFINED_CVT_UKERNEL_WITH_PARAMS -#undef XNN_DEFINED_CVT_UKERNEL_WITH_PARAMS -#undef XNN_CVT_UKERNEL_WITH_PARAMS -#endif - -#ifdef XNN_DEFINED_CVT_UKERNEL -#undef XNN_DEFINED_CVT_UKERNEL -#undef XNN_CVT_UKERNEL -#endif - -#ifdef XNN_DEFINED_QUANTIZED -#undef XNN_DEFINED_QUANTIZED -#undef XNN_QUANTIZED -#endif - diff --git a/src/s32-f32-vcvt/simd.c.in b/src/s32-f32-vcvt/simd.c.in deleted file mode 100644 index 6c387bac4c7d..000000000000 --- a/src/s32-f32-vcvt/simd.c.in +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2024 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
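Every variant registered in the table above implements the same element-wise contract: subtract the zero point from each int32 element and convert the difference to float, with `batch` given in bytes. A minimal scalar sketch of that contract (illustrative only, not part of this patch; the simplified params struct stands in for the `scalar` member of `struct xnn_s32_f32_cvt_params` used above):

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

struct s32_f32_cvt_scalar_params {
  int32_t zero_point;  // stand-in for params->scalar.zero_point above
};

// out[i] = (float) (in[i] - zero_point); `batch` is a byte count, matching the
// assertions in the removed micro-kernels.
static void s32_f32_vcvt_scalar_reference(
    size_t batch, const int32_t* input, float* output,
    const struct s32_f32_cvt_scalar_params* params) {
  assert(batch != 0);
  assert(batch % sizeof(int32_t) == 0);
  const int32_t zero_point = params->zero_point;
  for (size_t i = 0; i < batch / sizeof(int32_t); i++) {
    output[i] = (float) (input[i] - zero_point);
  }
}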
- -$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" -$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(",")) -$SIMD_SIZE = BATCH_TILES[0] - -#include -#include - -#include "xnnpack/simd/f32-${ARCH}.h" -#include "xnnpack/simd/s32-${ARCH}.h" - -#include "xnnpack/common.h" -#include "xnnpack/microparams.h" - -$for BATCH_TILE in BATCH_TILES: - - void xnn_s32_f32_vcvt_ukernel__${ARCH}_u${BATCH_TILE}( - size_t batch, - const int32_t* input, - float* output, - const struct xnn_s32_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) - { - assert(batch != 0); - assert(batch % sizeof(float) == 0); - assert(input != NULL); - assert(output != NULL); - assert(xnn_simd_size_f32 == ${SIMD_SIZE}); - assert(xnn_simd_size_s32 == ${SIMD_SIZE}); - - const xnn_simd_s32_t sub = xnn_set1_s32(params->scalar.zero_point); - - $SIMD_TILE = BATCH_TILE // SIMD_SIZE - $if SIMD_TILE > 1: - for (; batch >= ${BATCH_TILE} * sizeof(int32_t); batch -= ${BATCH_TILE} * sizeof(int32_t)) { - const xnn_simd_s32_t vx${ABC[0]} = xnn_loadu_s32(input); - $for N in range(1, SIMD_TILE): - const xnn_simd_s32_t vx${ABC[N]} = xnn_loadu_s32(input + ${N} * xnn_simd_size_s32); - input += ${SIMD_TILE} * xnn_simd_size_s32; - - $for N in range(SIMD_TILE): - const xnn_simd_f32_t vy${ABC[N]} = xnn_cvt_f32_s32(xnn_sub_s32(vx${ABC[N]}, sub)); - - xnn_storeu_f32(output, vy${ABC[0]}); - $for N in range(1, SIMD_TILE): - xnn_storeu_f32(output + ${N} * xnn_simd_size_f32, vy${ABC[N]}); - output += ${SIMD_TILE} * xnn_simd_size_f32; - } - - for (; batch >= xnn_simd_bytes_s32; batch -= xnn_simd_bytes_s32) { - const xnn_simd_s32_t vx = xnn_loadu_s32(input); - input += xnn_simd_size_f32; - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_storeu_f32(output, vy); - output += xnn_simd_size_f32; - } - - if (batch != 0) { - const xnn_simd_s32_t vx = - xnn_load_tail_s32(input, batch >> XNN_LOG2_SIZEOF_INT32_T); - - const xnn_simd_f32_t vy = xnn_cvt_f32_s32(xnn_sub_s32(vx, sub)); - - xnn_store_tail_f32(output, vy, batch >> XNN_LOG2_SIZEOF_FLOAT); - } - } diff --git a/src/subgraph.c b/src/subgraph.c index 0fc36e43d40a..9e4566eff2ad 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -17,6 +17,8 @@ #include "xnnpack/allocation-type.h" #include "xnnpack/allocator.h" #include "xnnpack/common.h" +#include "xnnpack/config-types.h" +#include "xnnpack/config.h" #include "xnnpack/fp16.h" #include "xnnpack/hardware-config.h" #include "xnnpack/internal.h" @@ -191,6 +193,7 @@ void xnn_value_copy( dst_value->data = src_value->data; dst_value->producer = src_value->producer; dst_value->first_consumer = src_value->first_consumer; + dst_value->all_consumers_types_same = src_value->all_consumers_types_same; dst_value->num_consumers = src_value->num_consumers; dst_value->num_nchw_compatible_consumers = src_value->num_nchw_compatible_consumers; dst_value->layout = src_value->layout; @@ -199,6 +202,8 @@ void xnn_value_copy( dst_value->fp32_id = src_value->fp32_id; dst_value->fp16_temp_data = src_value->fp16_temp_data; dst_value->fp32_data = src_value->fp32_data; + dst_value->gemm_config = src_value->gemm_config; + dst_value->squash_groups = src_value->squash_groups; } struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph) @@ -277,6 +282,10 @@ void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph) if (subgraph->values[input_id].num_consumers++ == 0) { assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID); subgraph->values[input_id].first_consumer = n; + subgraph->values[input_id].all_consumers_types_same = 
true; + } else { + enum xnn_node_type first_consumer_type = subgraph->nodes[subgraph->values[input_id].first_consumer].type; + subgraph->values[input_id].all_consumers_types_same &= (first_consumer_type == node->type); } } @@ -634,7 +643,7 @@ void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph) if (!update) { return; } - // Propagate the cluster leader to other nodes in the graph untill all the + // Propagate the cluster leader to other nodes in the graph until all the // nodes in the cluster is not updated while (update) { update = false; @@ -749,10 +758,24 @@ void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph) const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3]; subgraph->nodes[node->cluster_leader].num_params += num_params; - const float* data = (const float*) filter->data; size_t num_zeroes = 0; - for (size_t i = 0; i < num_params; i++) { - num_zeroes += (size_t) (data[i] == 0.0f); + switch (filter->datatype) { + case xnn_datatype_fp32: { + const float* data = (const float*)filter->data; + for (size_t i = 0; i < num_params; i++) { + num_zeroes += (size_t)(data[i] == 0.0f); + } + break; + } + case xnn_datatype_fp16: { + const xnn_float16* data = (const xnn_float16*)filter->data; + for (size_t i = 0; i < num_params; i++) { + num_zeroes += (size_t)(xnn_float16_is_zero(data[i])); + } + break; + } + default: + XNN_UNREACHABLE; } xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params); subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes; @@ -908,7 +931,11 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) } break; case xnn_node_type_fully_connected: - if (subgraph->values[node->inputs[0]].datatype == xnn_datatype_qdint8) { + if (subgraph->values[node->inputs[0]].datatype == xnn_datatype_qdint8 || + subgraph->values[node->inputs[0]].datatype == xnn_datatype_qpint8) { + // TODO(b/340399245) - Coerce any `qpint8` values back to `qdint8` for + // conversion to fp16. + subgraph->values[node->inputs[0]].datatype = xnn_datatype_qdint8; subgraph->values[node->outputs[0]].fp16_compatible = true; } else if (subgraph->values[node->inputs[0]].datatype == xnn_datatype_fp32 && @@ -1379,9 +1406,159 @@ enum xnn_status xnn_subgraph_fusion( return xnn_status_success; } +void xnn_subgraph_optimize_dynamic_quantization_ops(xnn_subgraph_t subgraph) { + enum xnn_weights_type { + xnn_weights_type_invalid = 0, + xnn_weights_type_qb4w = 1, + xnn_weights_type_qc4w = 2, + xnn_weights_type_qc8w = 4, + }; + enum xnn_consumer_type { + xnn_consumer_type_invalid = 0, + xnn_consumer_type_batch_mat_mul = 1, + xnn_consumer_type_convolution_2d = 2, + xnn_consumer_type_deconvolution = 4, + xnn_consumer_type_fully_connected = 8, + }; + for (uint32_t n = 0; n < subgraph->num_nodes; n++) { + enum xnn_consumer_type consumer_type = xnn_consumer_type_invalid; + enum xnn_weights_type weights_type = xnn_weights_type_invalid; + struct xnn_node* node = &subgraph->nodes[n]; + const uint32_t input_id = node->inputs[0]; + const uint32_t output_id = node->outputs[0]; + struct xnn_value* input = &subgraph->values[input_id]; + struct xnn_value* output = &subgraph->values[output_id]; + // Only replace nodes for which all consumer are of the same type. 
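+    // (The flag checked below is maintained in
+    // xnn_subgraph_analyze_consumers_and_producers: it is set to true when the
+    // first consumer of a value is recorded and AND-ed with
+    // "consumer type == first consumer's type" for every later consumer.)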
+ if (!output->all_consumers_types_same) continue; + if (output->datatype == xnn_datatype_qdint8) { + struct xnn_node* first_consumer_node = &subgraph->nodes[output->first_consumer]; + switch (first_consumer_node->type) { + case xnn_node_type_fully_connected: + consumer_type = xnn_consumer_type_fully_connected; + break; + case xnn_node_type_convolution_2d: + consumer_type = xnn_consumer_type_convolution_2d; + break; + case xnn_node_type_deconvolution_2d: + consumer_type = xnn_consumer_type_deconvolution; + break; + case xnn_node_type_batch_matrix_multiply: + consumer_type = xnn_consumer_type_batch_mat_mul; + break; + default: + XNN_UNREACHABLE; + } + const struct xnn_value* filter = &subgraph->values[first_consumer_node->inputs[1]]; + switch (filter->datatype) { + case xnn_datatype_qbint4: + weights_type = xnn_weights_type_qb4w; + break; + case xnn_datatype_qcint4: + weights_type = xnn_weights_type_qc4w; + break; + case xnn_datatype_qcint8: + weights_type = xnn_weights_type_qc8w; + break; + default: + XNN_UNREACHABLE; + } + bool pack_activations = false; + if (input->datatype == xnn_datatype_fp32) { + // Coerce the input from `xnn_datatype_qdint8` to `xnn_datatype_qpint8` if we + // know that we're converting for a GEMM and `qp8_f32_*` kernels are + // available. + // TODO(b/340399245) - Remove xnn_init_qp8_f32_qc4w_gemm_config check once we + // have full qp8 support. + + if (consumer_type == xnn_consumer_type_fully_connected || + consumer_type == xnn_consumer_type_batch_mat_mul) { + if ((weights_type == xnn_weights_type_qc4w) && + xnn_init_qp8_f32_qc4w_gemm_config() != NULL) { + pack_activations = true; + } else if ((weights_type == xnn_weights_type_qc8w) && + xnn_init_qp8_f32_qc8w_gemm_config() != NULL) { + pack_activations = true; + } else if ((weights_type == xnn_weights_type_qb4w) && + xnn_init_qp8_f32_qb4w_gemm_config() != NULL) { + pack_activations = true; + } + } + if (pack_activations) { + xnn_log_debug("Coercing type of output ID #%" PRIu32 + " of %s operator from `%s` to `%s`.", + output_id, + xnn_node_type_to_string(xnn_node_type_convert), + xnn_datatype_to_string(output->datatype), + xnn_datatype_to_string(xnn_datatype_qpint8)); + subgraph->values[output_id].datatype = xnn_datatype_qpint8; + switch (weights_type) { + case xnn_weights_type_qb4w: + output->gemm_config = xnn_init_qp8_f32_qb4w_gemm_config(); + break; + case xnn_weights_type_qc4w: + output->gemm_config = xnn_init_qp8_f32_qc4w_gemm_config(); + break; + case xnn_weights_type_qc8w: + output->gemm_config = xnn_init_qp8_f32_qc8w_gemm_config(); + break; + default: + XNN_UNREACHABLE; + } + // To prevent issues with packing, coerce the shape of the inputs from + // `[B, M, K]` to `[B * M, K]` for the fully-connected op. 
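+          // For example, a `[2, 3, 768]` activation feeding a fully-connected
+          // consumer is packed as a single `[6, 768]` matrix: the leading
+          // dimensions are collapsed into rows and only the innermost
+          // dimension is kept as columns.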
+ if (consumer_type == xnn_consumer_type_fully_connected) { + output->squash_groups = true; + } + } + } + + if (!pack_activations) { + const struct xnn_gemm_config *original_config = NULL; + const struct xnn_gemm_config *unsigned_config = NULL; + if (input->datatype == xnn_datatype_fp32) { + if (weights_type == xnn_weights_type_qc4w) { + original_config = xnn_init_qd8_f32_qc4w_gemm_config(); + unsigned_config = xnn_init_qdu8_f32_qc4w_gemm_config(); + } else if (weights_type == xnn_weights_type_qc8w) { + original_config = xnn_init_qd8_f32_qc8w_gemm_config(); + unsigned_config = xnn_init_qdu8_f32_qc8w_gemm_config(); + } else if (weights_type == xnn_weights_type_qb4w) { + original_config = xnn_init_qd8_f32_qb4w_gemm_config(); + unsigned_config = xnn_init_qdu8_f32_qb4w_gemm_config(); + } + } else if (input->datatype == xnn_datatype_fp16) { + if (weights_type == xnn_weights_type_qc4w) { + original_config = xnn_init_qd8_f16_qc4w_gemm_config(); + unsigned_config = xnn_init_qdu8_f16_qc4w_gemm_config(); + } else if (weights_type == xnn_weights_type_qc8w) { + original_config = xnn_init_qd8_f16_qc8w_gemm_config(); + unsigned_config = xnn_init_qdu8_f16_qc8w_gemm_config(); + } + } + bool convert_to_qu8 = false; + if (original_config && unsigned_config) { + enum xnn_arch_flags qdu8_arch = unsigned_config->arch; + enum xnn_arch_flags qd8_arch = original_config->arch; + if (qdu8_arch > qd8_arch) { + convert_to_qu8 = true; + } + } + if (convert_to_qu8) { + xnn_log_debug("Coercing type of output ID #%" PRIu32 + " of %s operator from `%s` to `%s`.", + output_id, xnn_node_type_to_string(xnn_node_type_convert), + xnn_datatype_to_string(output->datatype), + xnn_datatype_to_string(xnn_datatype_qduint8)); + subgraph->values[output_id].datatype = xnn_datatype_qduint8; + } + } + } + } +} + enum xnn_status xnn_subgraph_optimize( xnn_subgraph_t subgraph, - uint32_t flags) + uint32_t optimization_flags) { xnn_subgraph_analyze_consumers_and_producers(subgraph); @@ -1403,7 +1580,7 @@ enum xnn_status xnn_subgraph_optimize( } } - if (!(flags & XNN_FLAG_NO_OPERATOR_FUSION)) { + if (!(optimization_flags & XNN_FLAG_NO_OPERATOR_FUSION)) { xnn_subgraph_fusion(subgraph); } @@ -1413,13 +1590,13 @@ enum xnn_status xnn_subgraph_optimize( return xnn_status_unsupported_hardware; } - if ((flags & XNN_FLAG_FORCE_FP16_INFERENCE) && (!xnn_is_f16_compatible_config(hardware_config))) { + if ((optimization_flags & XNN_FLAG_FORCE_FP16_INFERENCE) && (!xnn_is_f16_compatible_config(hardware_config))) { xnn_log_error("failed to force FP16 inference: hardware supports neither native nor emulated FP16 operators"); return xnn_status_unsupported_hardware; } const bool try_native_fp16 = - (flags & XNN_FLAG_HINT_FP16_INFERENCE) && xnn_is_f16_supported_natively(hardware_config); - const bool force_fp16 = (flags & XNN_FLAG_FORCE_FP16_INFERENCE); + (optimization_flags & XNN_FLAG_HINT_FP16_INFERENCE) && xnn_is_f16_supported_natively(hardware_config); + const bool force_fp16 = (optimization_flags & XNN_FLAG_FORCE_FP16_INFERENCE); if (try_native_fp16 || force_fp16) { const bool fp16_rewrite_succeeded = xnn_subgraph_rewrite_for_fp16(subgraph); if (force_fp16 && !fp16_rewrite_succeeded) { @@ -1429,11 +1606,13 @@ enum xnn_status xnn_subgraph_optimize( } #if XNN_ENABLE_SPARSE - if ((flags & XNN_FLAG_HINT_SPARSE_INFERENCE) && (xnn_is_chw_compatible_config(hardware_config))) { + if ((optimization_flags & XNN_FLAG_HINT_SPARSE_INFERENCE) && (xnn_is_chw_compatible_config(hardware_config))) { xnn_subgraph_rewrite_for_nchw(subgraph); } #endif + 
xnn_subgraph_optimize_dynamic_quantization_ops(subgraph); + return xnn_status_success; } diff --git a/src/subgraph/batch-matrix-multiply.c b/src/subgraph/batch-matrix-multiply.c index ea2712ca3346..21f75ae57a8d 100644 --- a/src/subgraph/batch-matrix-multiply.c +++ b/src/subgraph/batch-matrix-multiply.c @@ -12,6 +12,7 @@ #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/common.h" +#include "xnnpack/internal.h" #include "xnnpack/log.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" @@ -43,87 +44,82 @@ static enum xnn_status create_batch_matrix_multiply_operator( const enum xnn_datatype inputa_datatype = values[input_a_id].datatype; const enum xnn_datatype inputb_datatype = values[input_b_id].datatype; + if (inputa_datatype == inputb_datatype && inputa_datatype == xnn_datatype_fp16) { + return xnn_create_batch_matrix_multiply_nc_f16(node->flags, &opdata->operator_objects[0]); + } + const struct xnn_value* input_b = values + input_b_id; + // Get the shape and size of the second input. + size_t batch_size_b = 1; + size_t k = 0; + size_t n = 0; + if (xnn_value_is_static(input_b)) { + if (input_b->shape.num_dims < 2) { + xnn_log_error( + "failed to create %s operator with input_b ID #%" PRIu32 + ": unsupported number of dimension %zu, must be at least 2", + xnn_node_type_to_string(xnn_node_type_batch_matrix_multiply), + input_b_id, input_b->shape.num_dims); + return xnn_status_invalid_parameter; + } + for (size_t i = 0; i < input_b->shape.num_dims - 2; i++) { + batch_size_b *= input_b->shape.dim[i]; + } + k = node->flags & XNN_FLAG_TRANSPOSE_B + ? input_b->shape.dim[input_b->shape.num_dims - 1] + : input_b->shape.dim[input_b->shape.num_dims - 2]; + n = node->flags & XNN_FLAG_TRANSPOSE_B + ? input_b->shape.dim[input_b->shape.num_dims - 2] + : input_b->shape.dim[input_b->shape.num_dims - 1]; + + } switch (inputa_datatype) { - case xnn_datatype_fp16: - switch (inputb_datatype) { - case xnn_datatype_fp16: - status = xnn_create_batch_matrix_multiply_nc_f16(node->flags, &opdata->operator_objects[0]); - break; - default: - XNN_UNREACHABLE; - } - break; case xnn_datatype_fp32: switch (inputb_datatype) { case xnn_datatype_fp32: { // Get the shape and size of the second input. - const uint32_t input_b_id = opdata->inputs[1]; - assert(input_b_id != XNN_INVALID_VALUE_ID); - assert(input_b_id < num_values); - const struct xnn_value* input_b = values + input_b_id; if (xnn_value_is_static(input_b)) { - if (input_b->shape.num_dims < 2) { - xnn_log_error( - "failed to create %s operator with input_b ID #%" PRIu32 - ": unsupported number of dimension %zu, must be at least 2", - xnn_node_type_to_string(xnn_node_type_batch_matrix_multiply), - input_b_id, input_b->shape.num_dims); - return xnn_status_invalid_parameter; - } - size_t batch_size_b = 1; - for (size_t i = 0; i < input_b->shape.num_dims - 2; i++) { - batch_size_b *= input_b->shape.dim[i]; - } - const size_t k = - node->flags & XNN_FLAG_TRANSPOSE_B - ? input_b->shape.dim[input_b->shape.num_dims - 1] - : input_b->shape.dim[input_b->shape.num_dims - 2]; - const size_t n = - node->flags & XNN_FLAG_TRANSPOSE_B - ? 
input_b->shape.dim[input_b->shape.num_dims - 2] - : input_b->shape.dim[input_b->shape.num_dims - 1]; - - status = xnn_create_batch_matrix_multiply_nc_f32_const_weights( + return xnn_create_batch_matrix_multiply_nc_f32_const_weights( batch_size_b, k, n, input_b->data, node->flags, &opdata->operator_objects[0]); } else { - status = xnn_create_batch_matrix_multiply_nc_f32( + return xnn_create_batch_matrix_multiply_nc_f32( node->flags, &opdata->operator_objects[0]); } - break; } default: XNN_UNREACHABLE; } break; case xnn_datatype_qdint8: { - // Get the shape and size of the second input. - const uint32_t input_b_id = opdata->inputs[1]; - assert(input_b_id != XNN_INVALID_VALUE_ID); - assert(input_b_id < num_values); - const struct xnn_value* input_b = values + input_b_id; - if (input_b->shape.num_dims < 2) { - xnn_log_error( - "failed to create %s operator with input_b ID #%" PRIu32 - ": unsupported number of dimension %zu, must be at least 2", - xnn_node_type_to_string(xnn_node_type_batch_matrix_multiply), - input_b_id, input_b->shape.num_dims); - return xnn_status_invalid_parameter; + switch (inputb_datatype) { + case xnn_datatype_qcint8: + status = xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( + batch_size_b, k, n, input_b->data, + input_b->quantization.channelwise_scale, node->flags, + &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; } - size_t batch_size_b = 1; - for (size_t i = 0; i < input_b->shape.num_dims - 2; i++) { - batch_size_b *= input_b->shape.dim[i]; + break; + } + case xnn_datatype_qpint8: { + switch (inputb_datatype) { + case xnn_datatype_qcint8: + status = xnn_create_batch_matrix_multiply_nc_qp8_f32_qc8w( + batch_size_b, k, n, input_b->data, + input_b->quantization.channelwise_scale, node->flags, + &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; } - const size_t k = node->flags & XNN_FLAG_TRANSPOSE_B - ? input_b->shape.dim[input_b->shape.num_dims - 1] - : input_b->shape.dim[input_b->shape.num_dims - 2]; - const size_t n = node->flags & XNN_FLAG_TRANSPOSE_B - ? 
input_b->shape.dim[input_b->shape.num_dims - 2] - : input_b->shape.dim[input_b->shape.num_dims - 1]; - + break; + } + case xnn_datatype_qduint8: { switch (inputb_datatype) { case xnn_datatype_qcint8: - status = xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( + status = xnn_create_batch_matrix_multiply_nc_qdu8_f32_qc8w( batch_size_b, k, n, input_b->data, input_b->quantization.channelwise_scale, node->flags, &opdata->operator_objects[0]); @@ -251,6 +247,16 @@ static enum xnn_status reshape_batch_matrix_multiply_operator( opdata->operator_objects[0], num_batch_dims, padded_dims_a, padded_dims_b, m, k, n, threadpool); break; + case xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w: + status = xnn_reshape_batch_matrix_multiply_nc_qp8_f32_qc8w( + opdata->operator_objects[0], num_batch_dims, padded_dims_a, + padded_dims_b, m, k, n, threadpool); + break; + case xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w: + status = xnn_reshape_batch_matrix_multiply_nc_qdu8_f32_qc8w( + opdata->operator_objects[0], num_batch_dims, padded_dims_a, + padded_dims_b, m, k, n, threadpool); + break; default: XNN_UNREACHABLE; } @@ -316,6 +322,13 @@ static enum xnn_status setup_batch_matrix_multiply_operator( return xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w( opdata->operator_objects[0], input_a_data, input_a->quantization.dynamic_params, output_data); + case xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w: + return xnn_setup_batch_matrix_multiply_nc_qp8_f32_qc8w( + opdata->operator_objects[0], input_a_data, output_data); + case xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w: + return xnn_setup_batch_matrix_multiply_nc_qdu8_f32_qc8w( + opdata->operator_objects[0], input_a_data, + input_a->quantization.dynamic_params, output_data); default: XNN_UNREACHABLE; } diff --git a/src/subgraph/convolution-2d.c b/src/subgraph/convolution-2d.c index 59aa28bf91b4..b87eff764b58 100644 --- a/src/subgraph/convolution-2d.c +++ b/src/subgraph/convolution-2d.c @@ -66,32 +66,74 @@ static enum xnn_status create_convolution_operator( : xnn_datatype_invalid; const enum xnn_datatype output_datatype = values[output_id].datatype; if (values[output_id].layout == xnn_layout_type_nchw) { - switch (output_datatype) { + switch (filter_datatype) { case xnn_datatype_fp16: - status = xnn_create_convolution2d_nchw_f16( - node->params.convolution_2d.input_padding_top, - node->params.convolution_2d.input_padding_right, - node->params.convolution_2d.input_padding_bottom, - node->params.convolution_2d.input_padding_left, - node->params.convolution_2d.kernel_height, - node->params.convolution_2d.kernel_width, - node->params.convolution_2d.subsampling_height, - node->params.convolution_2d.subsampling_width, - node->params.convolution_2d.dilation_height, - node->params.convolution_2d.dilation_width, - node->params.convolution_2d.groups, - node->params.convolution_2d.group_input_channels, - node->params.convolution_2d.group_output_channels, - node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */, - node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */, - filter_data, - bias_data, - node->activation.output_min, - node->activation.output_max, - node->flags | (values[input_id].layout == xnn_layout_type_nhwc ? 
XNN_FLAG_INPUT_NHWC : 0) | XNN_FLAG_FP32_STATIC_WEIGHTS, - code_cache, - weights_cache, - &opdata->operator_objects[0]); + switch (output_datatype) { + case xnn_datatype_fp32: { + uint32_t flags = + node->flags | (values[input_id].layout == xnn_layout_type_nhwc + ? XNN_FLAG_INPUT_NHWC + : 0); + if (bias_datatype == xnn_datatype_fp32) { + flags |= XNN_FLAG_FP32_STATIC_BIASES; + } + status = xnn_create_convolution2d_nchw_f32_f16( + node->params.convolution_2d.input_padding_top, + node->params.convolution_2d.input_padding_right, + node->params.convolution_2d.input_padding_bottom, + node->params.convolution_2d.input_padding_left, + node->params.convolution_2d.kernel_height, + node->params.convolution_2d.kernel_width, + node->params.convolution_2d.subsampling_height, + node->params.convolution_2d.subsampling_width, + node->params.convolution_2d.dilation_height, + node->params.convolution_2d.dilation_width, + node->params.convolution_2d.groups, + node->params.convolution_2d.group_input_channels, + node->params.convolution_2d.group_output_channels, + node->params.convolution_2d.group_input_channels * + node->params.convolution_2d.groups /* input_pixel_stride */, + node->params.convolution_2d.group_output_channels * + node->params.convolution_2d + .groups /* output_pixel_stride */, + filter_data, bias_data, node->activation.output_min, + node->activation.output_max, flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; + } + case xnn_datatype_fp16: + status = xnn_create_convolution2d_nchw_f16( + node->params.convolution_2d.input_padding_top, + node->params.convolution_2d.input_padding_right, + node->params.convolution_2d.input_padding_bottom, + node->params.convolution_2d.input_padding_left, + node->params.convolution_2d.kernel_height, + node->params.convolution_2d.kernel_width, + node->params.convolution_2d.subsampling_height, + node->params.convolution_2d.subsampling_width, + node->params.convolution_2d.dilation_height, + node->params.convolution_2d.dilation_width, + node->params.convolution_2d.groups, + node->params.convolution_2d.group_input_channels, + node->params.convolution_2d.group_output_channels, + node->params.convolution_2d.group_input_channels * + node->params.convolution_2d.groups /* input_pixel_stride */, + node->params.convolution_2d.group_output_channels * + node->params.convolution_2d + .groups /* output_pixel_stride */, + filter_data, bias_data, node->activation.output_min, + node->activation.output_max, + node->flags | + (values[input_id].layout == xnn_layout_type_nhwc + ? 
XNN_FLAG_INPUT_NHWC + : 0) | + XNN_FLAG_FP32_STATIC_WEIGHTS, + code_cache, weights_cache, &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + break; + } break; case xnn_datatype_fp32: status = xnn_create_convolution2d_nchw_f32( @@ -223,31 +265,64 @@ static enum xnn_status create_convolution_operator( } break; case xnn_datatype_qcint8: - status = xnn_create_convolution2d_nhwc_qd8_f32_qc8w( - node->params.convolution_2d.input_padding_top, - node->params.convolution_2d.input_padding_right, - node->params.convolution_2d.input_padding_bottom, - node->params.convolution_2d.input_padding_left, - node->params.convolution_2d.kernel_height, - node->params.convolution_2d.kernel_width, - node->params.convolution_2d.subsampling_height, - node->params.convolution_2d.subsampling_width, - node->params.convolution_2d.dilation_height, - node->params.convolution_2d.dilation_width, - node->params.convolution_2d.groups, - node->params.convolution_2d.group_input_channels, - node->params.convolution_2d.group_output_channels, - /*input_channel_stride=*/node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups, - /*output_channel_stride=*/node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups, - values[filter_id].quantization.channelwise_scale, - filter_data, - bias_data, - node->activation.output_min, - node->activation.output_max, - node->flags, - code_cache, - weights_cache, - &opdata->operator_objects[0]); + switch (input_datatype) { + case xnn_datatype_qdint8: + status = xnn_create_convolution2d_nhwc_qd8_f32_qc8w( + node->params.convolution_2d.input_padding_top, + node->params.convolution_2d.input_padding_right, + node->params.convolution_2d.input_padding_bottom, + node->params.convolution_2d.input_padding_left, + node->params.convolution_2d.kernel_height, + node->params.convolution_2d.kernel_width, + node->params.convolution_2d.subsampling_height, + node->params.convolution_2d.subsampling_width, + node->params.convolution_2d.dilation_height, + node->params.convolution_2d.dilation_width, + node->params.convolution_2d.groups, + node->params.convolution_2d.group_input_channels, + node->params.convolution_2d.group_output_channels, + /*input_channel_stride=*/node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups, + /*output_channel_stride=*/node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups, + values[filter_id].quantization.channelwise_scale, + filter_data, + bias_data, + node->activation.output_min, + node->activation.output_max, + node->flags, + code_cache, + weights_cache, + &opdata->operator_objects[0]); + break; + case xnn_datatype_qduint8: + status = xnn_create_convolution2d_nhwc_qdu8_f32_qc8w( + node->params.convolution_2d.input_padding_top, + node->params.convolution_2d.input_padding_right, + node->params.convolution_2d.input_padding_bottom, + node->params.convolution_2d.input_padding_left, + node->params.convolution_2d.kernel_height, + node->params.convolution_2d.kernel_width, + node->params.convolution_2d.subsampling_height, + node->params.convolution_2d.subsampling_width, + node->params.convolution_2d.dilation_height, + node->params.convolution_2d.dilation_width, + node->params.convolution_2d.groups, + node->params.convolution_2d.group_input_channels, + node->params.convolution_2d.group_output_channels, + /*input_channel_stride=*/node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups, + 
/*output_channel_stride=*/node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups, + values[filter_id].quantization.channelwise_scale, + filter_data, + bias_data, + node->activation.output_min, + node->activation.output_max, + node->flags, + code_cache, + weights_cache, + &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + } break; default: XNN_UNREACHABLE; @@ -289,34 +364,66 @@ static enum xnn_status create_convolution_operator( weights_cache, &opdata->operator_objects[0]); break; - case xnn_datatype_qcint8: { - status = xnn_create_convolution2d_nhwc_qd8_f16_qc8w( - node->params.convolution_2d.input_padding_top, - node->params.convolution_2d.input_padding_right, - node->params.convolution_2d.input_padding_bottom, - node->params.convolution_2d.input_padding_left, - node->params.convolution_2d.kernel_height, - node->params.convolution_2d.kernel_width, - node->params.convolution_2d.subsampling_height, - node->params.convolution_2d.subsampling_width, - node->params.convolution_2d.dilation_height, - node->params.convolution_2d.dilation_width, - node->params.convolution_2d.groups, - node->params.convolution_2d.group_input_channels, - node->params.convolution_2d.group_output_channels, - /*input_channel_stride=*/node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups, - /*output_channel_stride=*/node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups, - values[filter_id].quantization.channelwise_scale, - filter_data, - bias_data, - node->activation.output_min, - node->activation.output_max, - node->flags, - code_cache, - weights_cache, - &opdata->operator_objects[0]); + case xnn_datatype_qcint8: + switch (input_datatype) { + case xnn_datatype_qdint8: + status = xnn_create_convolution2d_nhwc_qd8_f16_qc8w( + node->params.convolution_2d.input_padding_top, + node->params.convolution_2d.input_padding_right, + node->params.convolution_2d.input_padding_bottom, + node->params.convolution_2d.input_padding_left, + node->params.convolution_2d.kernel_height, + node->params.convolution_2d.kernel_width, + node->params.convolution_2d.subsampling_height, + node->params.convolution_2d.subsampling_width, + node->params.convolution_2d.dilation_height, + node->params.convolution_2d.dilation_width, + node->params.convolution_2d.groups, + node->params.convolution_2d.group_input_channels, + node->params.convolution_2d.group_output_channels, + /*input_channel_stride=*/node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups, + /*output_channel_stride=*/node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups, + values[filter_id].quantization.channelwise_scale, + filter_data, + bias_data, + node->activation.output_min, + node->activation.output_max, + node->flags, + code_cache, + weights_cache, + &opdata->operator_objects[0]); + break; + case xnn_datatype_qduint8: + status = xnn_create_convolution2d_nhwc_qdu8_f16_qc8w( + node->params.convolution_2d.input_padding_top, + node->params.convolution_2d.input_padding_right, + node->params.convolution_2d.input_padding_bottom, + node->params.convolution_2d.input_padding_left, + node->params.convolution_2d.kernel_height, + node->params.convolution_2d.kernel_width, + node->params.convolution_2d.subsampling_height, + node->params.convolution_2d.subsampling_width, + node->params.convolution_2d.dilation_height, + node->params.convolution_2d.dilation_width, + node->params.convolution_2d.groups, + 
node->params.convolution_2d.group_input_channels, + node->params.convolution_2d.group_output_channels, + /*input_channel_stride=*/node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups, + /*output_channel_stride=*/node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups, + values[filter_id].quantization.channelwise_scale, + filter_data, + bias_data, + node->activation.output_min, + node->activation.output_max, + node->flags, + code_cache, + weights_cache, + &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + } break; - } default: XNN_UNREACHABLE; } @@ -514,6 +621,18 @@ static enum xnn_status reshape_convolution_operator( &output_width, threadpool); break; + case xnn_operator_type_convolution_nhwc_qdu8_f16_qc8w: + status = xnn_reshape_convolution2d_nhwc_qdu8_f16_qc8w( + opdata->operator_objects[0], + batch_size, + input_height, + input_width, + &opdata->workspace_size, + &opdata->workspace_alignment, + &output_height, + &output_width, + threadpool); + break; case xnn_operator_type_convolution_nhwc_qd8_f32_qc8w: status = xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( opdata->operator_objects[0], @@ -526,6 +645,18 @@ static enum xnn_status reshape_convolution_operator( &output_width, threadpool); break; + case xnn_operator_type_convolution_nhwc_qdu8_f32_qc8w: + status = xnn_reshape_convolution2d_nhwc_qdu8_f32_qc8w( + opdata->operator_objects[0], + batch_size, + input_height, + input_width, + &opdata->workspace_size, + &opdata->workspace_alignment, + &output_height, + &output_width, + threadpool); + break; case xnn_operator_type_convolution_nhwc_qc8: status = xnn_reshape_convolution2d_nhwc_qs8_qc8w( opdata->operator_objects[0], @@ -654,6 +785,18 @@ static enum xnn_status setup_convolution_operator( quantization_params); } break; + case xnn_operator_type_convolution_nhwc_qdu8_f16_qc8w: + { + const void* quantization_params = input_value->quantization.dynamic_params; + assert(quantization_params != NULL); + return xnn_setup_convolution2d_nhwc_qdu8_f16_qc8w( + opdata->operator_objects[0], + opdata->workspace, + input_data, + output_data, + quantization_params); + } + break; case xnn_operator_type_convolution_nhwc_qd8_f32_qc8w: { const void* quantization_params = input_value->quantization.dynamic_params; @@ -666,6 +809,18 @@ static enum xnn_status setup_convolution_operator( quantization_params); } break; + case xnn_operator_type_convolution_nhwc_qdu8_f32_qc8w: + { + const void* quantization_params = input_value->quantization.dynamic_params; + assert(quantization_params != NULL); + return xnn_setup_convolution2d_nhwc_qdu8_f32_qc8w( + opdata->operator_objects[0], + opdata->workspace, + input_data, + output_data, + quantization_params); + } + break; case xnn_operator_type_convolution_nhwc_qs8: return xnn_setup_convolution2d_nhwc_qs8( opdata->operator_objects[0], diff --git a/src/subgraph/deconvolution-2d.c b/src/subgraph/deconvolution-2d.c index 2bb7c8f65359..378b696b7106 100644 --- a/src/subgraph/deconvolution-2d.c +++ b/src/subgraph/deconvolution-2d.c @@ -12,6 +12,7 @@ #include "xnnpack/common.h" #include "xnnpack/log.h" #include "xnnpack/node-type.h" +#include "xnnpack/internal.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" #include "xnnpack/requantization.h" @@ -58,6 +59,7 @@ static enum xnn_status create_deconvolution_operator( assert(filter_data != NULL); enum xnn_status status = xnn_status_uninitialized; + const enum xnn_datatype input_datatype = values[input_id].datatype; const enum 
xnn_datatype filter_datatype = values[filter_id].datatype; const enum xnn_datatype bias_datatype = bias_id != XNN_INVALID_VALUE_ID ? values[filter_id].datatype @@ -153,31 +155,64 @@ static enum xnn_status create_deconvolution_operator( &opdata->operator_objects[0]); break; case xnn_datatype_qcint8: - status = xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( - node->params.deconvolution_2d.padding_top, - node->params.deconvolution_2d.padding_right, - node->params.deconvolution_2d.padding_bottom, - node->params.deconvolution_2d.padding_left, - node->params.deconvolution_2d.kernel_height, - node->params.deconvolution_2d.kernel_width, - node->params.deconvolution_2d.upsampling_height, - node->params.deconvolution_2d.upsampling_width, - node->params.deconvolution_2d.dilation_height, - node->params.deconvolution_2d.dilation_width, - node->params.deconvolution_2d.groups, - node->params.deconvolution_2d.group_input_channels, - node->params.deconvolution_2d.group_output_channels, - node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */, - node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */, - values[filter_id].quantization.channelwise_scale, - filter_data, - bias_data, - node->activation.output_min, - node->activation.output_max, - node->flags, - code_cache, - weights_cache, - &opdata->operator_objects[0]); + switch (input_datatype) { + case xnn_datatype_qdint8: + status = xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( + node->params.deconvolution_2d.padding_top, + node->params.deconvolution_2d.padding_right, + node->params.deconvolution_2d.padding_bottom, + node->params.deconvolution_2d.padding_left, + node->params.deconvolution_2d.kernel_height, + node->params.deconvolution_2d.kernel_width, + node->params.deconvolution_2d.upsampling_height, + node->params.deconvolution_2d.upsampling_width, + node->params.deconvolution_2d.dilation_height, + node->params.deconvolution_2d.dilation_width, + node->params.deconvolution_2d.groups, + node->params.deconvolution_2d.group_input_channels, + node->params.deconvolution_2d.group_output_channels, + node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */, + node->params.deconvolution_2d.group_output_channels * node->params.deconvolution_2d.groups /* output_pixel_stride */, + values[filter_id].quantization.channelwise_scale, + filter_data, + bias_data, + node->activation.output_min, + node->activation.output_max, + node->flags, + code_cache, + weights_cache, + &opdata->operator_objects[0]); + break; + case xnn_datatype_qduint8: + status = xnn_create_deconvolution2d_nhwc_qdu8_f32_qc8w( + node->params.deconvolution_2d.padding_top, + node->params.deconvolution_2d.padding_right, + node->params.deconvolution_2d.padding_bottom, + node->params.deconvolution_2d.padding_left, + node->params.deconvolution_2d.kernel_height, + node->params.deconvolution_2d.kernel_width, + node->params.deconvolution_2d.upsampling_height, + node->params.deconvolution_2d.upsampling_width, + node->params.deconvolution_2d.dilation_height, + node->params.deconvolution_2d.dilation_width, + node->params.deconvolution_2d.groups, + node->params.deconvolution_2d.group_input_channels, + node->params.deconvolution_2d.group_output_channels, + node->params.deconvolution_2d.group_input_channels * node->params.deconvolution_2d.groups /* input_pixel_stride */, + node->params.deconvolution_2d.group_output_channels * 
node->params.deconvolution_2d.groups /* output_pixel_stride */, + values[filter_id].quantization.channelwise_scale, + filter_data, + bias_data, + node->activation.output_min, + node->activation.output_max, + node->flags, + code_cache, + weights_cache, + &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + } break; default: XNN_UNREACHABLE; @@ -395,6 +430,18 @@ static enum xnn_status reshape_deconvolution_operator( &output_width, threadpool); break; + case xnn_operator_type_deconvolution_nhwc_qdu8_f32_qc8w: + status = xnn_reshape_deconvolution2d_nhwc_qdu8_f32_qc8w( + opdata->operator_objects[0], + batch_size, + input_height, + input_width, + opdata->adjustment_height, + opdata->adjustment_width, + &output_height, + &output_width, + threadpool); + break; default: XNN_UNREACHABLE; } @@ -483,6 +530,17 @@ static enum xnn_status setup_deconvolution_operator( quantization_params); } break; + case xnn_operator_type_deconvolution_nhwc_qdu8_f32_qc8w: + { + const void* quantization_params = input_value->quantization.dynamic_params; + assert(quantization_params != NULL); + return xnn_setup_deconvolution2d_nhwc_qdu8_f32_qc8w( + opdata->operator_objects[0], + input_data, + output_data, + quantization_params); + } + break; default: XNN_UNREACHABLE; } diff --git a/src/subgraph/depthwise-convolution-2d.c b/src/subgraph/depthwise-convolution-2d.c index fff7f186188e..02dc36f0eba0 100644 --- a/src/subgraph/depthwise-convolution-2d.c +++ b/src/subgraph/depthwise-convolution-2d.c @@ -10,6 +10,7 @@ #include "xnnpack.h" #include "xnnpack/common.h" +#include "xnnpack/internal.h" #include "xnnpack/log.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-type.h" @@ -45,8 +46,9 @@ static enum xnn_status create_convolution_operator( assert(filter_data != NULL); const void* bias_data = NULL; + uint32_t bias_id = XNN_INVALID_VALUE_ID; if (node->num_inputs > 2) { - const uint32_t bias_id = node->inputs[2]; + bias_id = node->inputs[2]; assert(bias_id != XNN_INVALID_VALUE_ID); assert(bias_id < num_values); @@ -57,6 +59,9 @@ static enum xnn_status create_convolution_operator( enum xnn_status status; const enum xnn_datatype filter_datatype = values[filter_id].datatype; const enum xnn_datatype output_datatype = values[output_id].datatype; + const enum xnn_datatype bias_datatype = bias_id != XNN_INVALID_VALUE_ID + ? 
values[filter_id].datatype + : xnn_datatype_invalid; if (values[output_id].layout == xnn_layout_type_nchw) { assert(values[input_id].layout == xnn_layout_type_nchw); switch (filter_datatype) { @@ -87,30 +92,69 @@ static enum xnn_status create_convolution_operator( &opdata->operator_objects[0]); break; case xnn_datatype_fp16: - status = xnn_create_convolution2d_nchw_f16( - node->params.depthwise_convolution_2d.input_padding_top, - node->params.depthwise_convolution_2d.input_padding_right, - node->params.depthwise_convolution_2d.input_padding_bottom, - node->params.depthwise_convolution_2d.input_padding_left, - node->params.depthwise_convolution_2d.kernel_height, - node->params.depthwise_convolution_2d.kernel_width, - node->params.depthwise_convolution_2d.subsampling_height, - node->params.depthwise_convolution_2d.subsampling_width, - node->params.depthwise_convolution_2d.dilation_height, - node->params.depthwise_convolution_2d.dilation_width, - node->params.depthwise_convolution_2d.input_channels /* groups */, - 1 /* group_input_channels */, - node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */, - node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */, - node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */, - filter_data, - bias_data, - node->activation.output_min, - node->activation.output_max, - node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION, - code_cache, - weights_cache, - &opdata->operator_objects[0]); + switch (output_datatype) { + case xnn_datatype_fp32: { + uint32_t flags = node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION; + if (bias_datatype == xnn_datatype_fp32) { + flags |= XNN_FLAG_FP32_STATIC_BIASES; + } + status = xnn_create_convolution2d_nchw_f32_f16( + node->params.depthwise_convolution_2d.input_padding_top, + node->params.depthwise_convolution_2d.input_padding_right, + node->params.depthwise_convolution_2d.input_padding_bottom, + node->params.depthwise_convolution_2d.input_padding_left, + node->params.depthwise_convolution_2d.kernel_height, + node->params.depthwise_convolution_2d.kernel_width, + node->params.depthwise_convolution_2d.subsampling_height, + node->params.depthwise_convolution_2d.subsampling_width, + node->params.depthwise_convolution_2d.dilation_height, + node->params.depthwise_convolution_2d.dilation_width, + node->params.depthwise_convolution_2d + .input_channels /* groups */, + 1 /* group_input_channels */, + node->params.depthwise_convolution_2d + .depth_multiplier /* group_output_channels */, + node->params.depthwise_convolution_2d + .input_channels /* input_channel_stride */, + node->params.depthwise_convolution_2d.input_channels * + node->params.depthwise_convolution_2d + .depth_multiplier /* output_channel_stride */, + filter_data, bias_data, node->activation.output_min, + node->activation.output_max, flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; + } + case xnn_datatype_fp16: + status = xnn_create_convolution2d_nchw_f16( + node->params.depthwise_convolution_2d.input_padding_top, + node->params.depthwise_convolution_2d.input_padding_right, + node->params.depthwise_convolution_2d.input_padding_bottom, + node->params.depthwise_convolution_2d.input_padding_left, + node->params.depthwise_convolution_2d.kernel_height, + node->params.depthwise_convolution_2d.kernel_width, + node->params.depthwise_convolution_2d.subsampling_height, + node->params.depthwise_convolution_2d.subsampling_width, + 
node->params.depthwise_convolution_2d.dilation_height, + node->params.depthwise_convolution_2d.dilation_width, + node->params.depthwise_convolution_2d + .input_channels /* groups */, + 1 /* group_input_channels */, + node->params.depthwise_convolution_2d + .depth_multiplier /* group_output_channels */, + node->params.depthwise_convolution_2d + .input_channels /* input_channel_stride */, + node->params.depthwise_convolution_2d.input_channels * + node->params.depthwise_convolution_2d + .depth_multiplier /* output_channel_stride */, + filter_data, bias_data, node->activation.output_min, + node->activation.output_max, + node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION, code_cache, + weights_cache, &opdata->operator_objects[0]); + break; + default: + XNN_UNREACHABLE; + break; + } break; default: XNN_UNREACHABLE; diff --git a/src/subgraph/fully-connected-sparse.c b/src/subgraph/fully-connected-sparse.c index 912f2d14244f..ae7e4b00bb60 100644 --- a/src/subgraph/fully-connected-sparse.c +++ b/src/subgraph/fully-connected-sparse.c @@ -10,6 +10,7 @@ #include "xnnpack.h" #include "xnnpack/common.h" +#include "xnnpack/internal.h" #include "xnnpack/log.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-type.h" @@ -44,8 +45,9 @@ static enum xnn_status create_fully_connected_operator( assert(kernel_data != NULL); const void* bias_data = NULL; + uint32_t bias_id = XNN_INVALID_VALUE_ID; if (node->num_inputs > 2) { - const uint32_t bias_id = node->inputs[2]; + bias_id = node->inputs[2]; assert(bias_id != XNN_INVALID_VALUE_ID); assert(bias_id < num_values); @@ -55,6 +57,10 @@ static enum xnn_status create_fully_connected_operator( enum xnn_status status; enum xnn_datatype input_datatype = values[input_id].datatype; + const enum xnn_datatype filter_datatype = values[filter_id].datatype; + const enum xnn_datatype bias_datatype = bias_id != XNN_INVALID_VALUE_ID + ? 
values[filter_id].datatype + : xnn_datatype_invalid; switch (input_datatype) { case xnn_datatype_fp16: { @@ -85,34 +91,57 @@ static enum xnn_status create_fully_connected_operator( break; } case xnn_datatype_fp32: - { - assert(values[filter_id].datatype == xnn_datatype_fp32); - status = xnn_create_convolution2d_nchw_f32( - /*input_padding_top=*/0, - /*input_padding_right=*/0, - /*input_padding_bottom=*/0, - /*input_padding_left=*/0, - /*kernel_height=*/1, - /*kernel_width=*/1, - /*subsampling_height=*/1, - /*subsampling_width=*/1, - /*dilation_height=*/1, - /*dilation_width=*/1, - /*groups=*/1, - /*group_input_channels=*/input_channels, - /*group_output_channels=*/output_channels, - /*input_channel_stride=*/input_channels, - /*output_channel_stride=*/output_channels, - kernel_data, - bias_data, - node->activation.output_min, - node->activation.output_max, - node->flags, - code_cache, - weights_cache, - &opdata->operator_objects[0]); - break; - } + switch (filter_datatype) { + case xnn_datatype_fp32: + status = xnn_create_convolution2d_nchw_f32( + /*input_padding_top=*/0, + /*input_padding_right=*/0, + /*input_padding_bottom=*/0, + /*input_padding_left=*/0, + /*kernel_height=*/1, + /*kernel_width=*/1, + /*subsampling_height=*/1, + /*subsampling_width=*/1, + /*dilation_height=*/1, + /*dilation_width=*/1, + /*groups=*/1, + /*group_input_channels=*/input_channels, + /*group_output_channels=*/output_channels, + /*input_channel_stride=*/input_channels, + /*output_channel_stride=*/output_channels, kernel_data, bias_data, + node->activation.output_min, node->activation.output_max, + node->flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; + case xnn_datatype_fp16: { + uint32_t flags = node->flags; + if (bias_datatype == xnn_datatype_fp32) { + flags |= XNN_FLAG_FP32_STATIC_BIASES; + } + status = xnn_create_convolution2d_nchw_f32_f16( + /*input_padding_top=*/0, + /*input_padding_right=*/0, + /*input_padding_bottom=*/0, + /*input_padding_left=*/0, + /*kernel_height=*/1, + /*kernel_width=*/1, + /*subsampling_height=*/1, + /*subsampling_width=*/1, + /*dilation_height=*/1, + /*dilation_width=*/1, + /*groups=*/1, + /*group_input_channels=*/input_channels, + /*group_output_channels=*/output_channels, + /*input_channel_stride=*/input_channels, + /*output_channel_stride=*/output_channels, kernel_data, bias_data, + node->activation.output_min, node->activation.output_max, + flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; + } + default: + XNN_UNREACHABLE; + } default: XNN_UNREACHABLE; } diff --git a/src/subgraph/fully-connected.c b/src/subgraph/fully-connected.c index 008edaea42b3..6ef0f5583caa 100644 --- a/src/subgraph/fully-connected.c +++ b/src/subgraph/fully-connected.c @@ -18,7 +18,6 @@ #include "xnnpack/node-type.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" -#include "xnnpack/subgraph.h" #include "xnnpack/requantization.h" #include "xnnpack/subgraph-validation.h" #include "xnnpack/subgraph.h" @@ -48,6 +47,12 @@ enum fully_connected_op_type { fc_type_qp8_f32_qb4w = 19, fc_type_pf32_f32_f32 = 20, fc_type_f32_f16_f32 = 21, + fc_type_qdu8_f16_qc8w = 22, + fc_type_qdu8_f32_qc8w = 23, + fc_type_qdu8_f32_qc4w = 24, + fc_type_qdu8_f32_qb4w = 26, + fc_type_qdu8_f16_qc4w = 27, + fc_type_qp8_f32_qc8w = 28, }; enum fully_connected_op_type get_fully_connected_op_type( @@ -83,11 +88,27 @@ enum fully_connected_op_type get_fully_connected_op_type( return fc_type_f16_f32_f16; } case xnn_datatype_qcint4: - return fc_type_qd8_f16_qc4w; + switch 
(input_datatype) { + case xnn_datatype_qdint8: + return fc_type_qd8_f16_qc4w; + case xnn_datatype_qduint8: + return fc_type_qdu8_f16_qc4w; + default: + XNN_UNREACHABLE; + } + break; case xnn_datatype_qbint4: return fc_type_qd8_f16_qb4w; case xnn_datatype_qcint8: - return fc_type_qd8_f16_qc8w; + switch (input_datatype) { + case xnn_datatype_qdint8: + return fc_type_qd8_f16_qc8w; + case xnn_datatype_qduint8: + return fc_type_qdu8_f16_qc8w; + default: + XNN_UNREACHABLE; + } + break; default: XNN_UNREACHABLE; } @@ -118,6 +139,8 @@ enum fully_connected_op_type get_fully_connected_op_type( switch (input_datatype) { case xnn_datatype_qdint8: return fc_type_qd8_f32_qb4w; + case xnn_datatype_qduint8: + return fc_type_qdu8_f32_qb4w; case xnn_datatype_qpint8: return fc_type_qp8_f32_qb4w; default: @@ -129,6 +152,8 @@ enum fully_connected_op_type get_fully_connected_op_type( return fc_type_f32_f32_qc4w; case xnn_datatype_qdint8: return fc_type_qd8_f32_qc4w; + case xnn_datatype_qduint8: + return fc_type_qdu8_f32_qc4w; case xnn_datatype_qpint8: return fc_type_qp8_f32_qc4w; default: @@ -141,6 +166,10 @@ enum fully_connected_op_type get_fully_connected_op_type( return fc_type_f32_f32_qc8w; case xnn_datatype_qdint8: return fc_type_qd8_f32_qc8w; + case xnn_datatype_qpint8: + return fc_type_qp8_f32_qc8w; + case xnn_datatype_qduint8: + return fc_type_qdu8_f32_qc8w; default: XNN_UNREACHABLE; } @@ -253,6 +282,16 @@ static enum xnn_status create_fully_connected_operator( bias_data, node->activation.output_min, node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qdu8_f16_qc4w: + status = xnn_create_fully_connected_nc_qdu8_f16_qc4w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + /*kernel_zero_point=*/values[filter_id].quantization.zero_point, + values[filter_id].quantization.channelwise_scale, kernel_data, + bias_data, node->activation.output_min, node->activation.output_max, + node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); + break; case fc_type_qd8_f16_qb4w: status = xnn_create_fully_connected_nc_qd8_f16_qb4w( input_channels, output_channels, @@ -274,6 +313,15 @@ static enum xnn_status create_fully_connected_operator( bias_data, node->activation.output_min, node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qdu8_f16_qc8w: + status = xnn_create_fully_connected_nc_qdu8_f16_qc8w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + values[filter_id].quantization.channelwise_scale, kernel_data, + bias_data, node->activation.output_min, node->activation.output_max, + node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); + break; case fc_type_f32_f32_f32_dynamic: status = xnn_create_dynamic_fully_connected_nc_f32( node->activation.output_min, node->activation.output_max, @@ -309,6 +357,18 @@ static enum xnn_status create_fully_connected_operator( node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qdu8_f32_qb4w: + status = xnn_create_fully_connected_nc_qdu8_f32_qb4w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + /*block_size=*/values[filter_id].quantization.block_size, + /*kernel_zero_point=*/values[filter_id].quantization.zero_point, + (const 
uint16_t*)values[filter_id].quantization.blockwise_scale, + kernel_data, bias_data, node->activation.output_min, + node->activation.output_max, node->flags, code_cache, weights_cache, + &opdata->operator_objects[0]); + break; case fc_type_f32_f16_f32: { uint32_t flags = node->flags; if (bias_value != NULL && bias_value->datatype == xnn_datatype_fp32) { @@ -355,6 +415,16 @@ static enum xnn_status create_fully_connected_operator( bias_data, node->activation.output_min, node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qdu8_f32_qc4w: + status = xnn_create_fully_connected_nc_qdu8_f32_qc4w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + /*kernel_zero_point=*/values[filter_id].quantization.zero_point, + values[filter_id].quantization.channelwise_scale, kernel_data, + bias_data, node->activation.output_min, node->activation.output_max, + node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); + break; case fc_type_qp8_f32_qc4w: status = xnn_create_fully_connected_nc_qp8_f32_qc4w( input_channels, output_channels, @@ -365,6 +435,15 @@ static enum xnn_status create_fully_connected_operator( bias_data, node->activation.output_min, node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qp8_f32_qc8w: + status = xnn_create_fully_connected_nc_qp8_f32_qc8w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + values[filter_id].quantization.channelwise_scale, kernel_data, + bias_data, node->activation.output_min, node->activation.output_max, + node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); + break; case fc_type_f32_f32_qc8w: status = xnn_create_fully_connected_nc_f32_qc8w( input_channels, output_channels, @@ -384,6 +463,15 @@ static enum xnn_status create_fully_connected_operator( bias_data, node->activation.output_min, node->activation.output_max, node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); break; + case fc_type_qdu8_f32_qc8w: + status = xnn_create_fully_connected_nc_qdu8_f32_qc8w( + input_channels, output_channels, + /*input_stride=*/input_channels, + /*output_stride=*/output_channels, + values[filter_id].quantization.channelwise_scale, kernel_data, + bias_data, node->activation.output_min, node->activation.output_max, + node->flags, code_cache, weights_cache, &opdata->operator_objects[0]); + break; case fc_type_qs8_qs8_qc8w: assert(!has_non_static_weights); assert(kernel_data != NULL); @@ -559,10 +647,18 @@ static enum xnn_status reshape_fully_connected_operator( status = xnn_reshape_fully_connected_nc_qd8_f32_qc4w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qdu8_f32_qc4w: + status = xnn_reshape_fully_connected_nc_qdu8_f32_qc4w( + opdata->operator_objects[0], batch_size, threadpool); + break; case xnn_operator_type_fully_connected_nc_qd8_f16_qc4w: status = xnn_reshape_fully_connected_nc_qd8_f16_qc4w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qdu8_f16_qc4w: + status = xnn_reshape_fully_connected_nc_qdu8_f16_qc4w( + opdata->operator_objects[0], batch_size, threadpool); + break; case xnn_operator_type_fully_connected_nc_qd8_f16_qb4w: status = xnn_reshape_fully_connected_nc_qd8_f16_qb4w( opdata->operator_objects[0], batch_size, threadpool); @@ -571,18 +667,34 @@ 
static enum xnn_status reshape_fully_connected_operator( status = xnn_reshape_fully_connected_nc_qd8_f32_qb4w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qdu8_f32_qb4w: + status = xnn_reshape_fully_connected_nc_qdu8_f32_qb4w( + opdata->operator_objects[0], batch_size, threadpool); + break; case xnn_operator_type_fully_connected_nc_qd8_f16_qc8w: status = xnn_reshape_fully_connected_nc_qd8_f16_qc8w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qdu8_f16_qc8w: + status = xnn_reshape_fully_connected_nc_qdu8_f16_qc8w( + opdata->operator_objects[0], batch_size, threadpool); + break; case xnn_operator_type_fully_connected_nc_qd8_f32_qc8w: status = xnn_reshape_fully_connected_nc_qd8_f32_qc8w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qdu8_f32_qc8w: + status = xnn_reshape_fully_connected_nc_qdu8_f32_qc8w( + opdata->operator_objects[0], batch_size, threadpool); + break; case xnn_operator_type_fully_connected_nc_qp8_f32_qc4w: status = xnn_reshape_fully_connected_nc_qp8_f32_qc4w( opdata->operator_objects[0], batch_size, threadpool); break; + case xnn_operator_type_fully_connected_nc_qp8_f32_qc8w: + status = xnn_reshape_fully_connected_nc_qp8_f32_qc8w( + opdata->operator_objects[0], batch_size, threadpool); + break; case xnn_operator_type_fully_connected_nc_qp8_f32_qb4w: status = xnn_reshape_fully_connected_nc_qp8_f32_qb4w( opdata->operator_objects[0], @@ -699,6 +811,16 @@ static enum xnn_status setup_fully_connected_operator( opdata->operator_objects[0], input_data, output_data, quantization_params); } + case xnn_operator_type_fully_connected_nc_qdu8_f32_qc4w: { + const void* quantization_params = + input_value->quantization.dynamic_params; + assert(kernel_data == NULL); + assert(bias_data == NULL); + assert(quantization_params != NULL); + return xnn_setup_fully_connected_nc_qdu8_f32_qc4w( + opdata->operator_objects[0], input_data, output_data, + quantization_params); + } case xnn_operator_type_fully_connected_nc_qd8_f16_qc4w: { const void* quantization_params = input_value->quantization.dynamic_params; @@ -709,6 +831,16 @@ static enum xnn_status setup_fully_connected_operator( opdata->operator_objects[0], input_data, output_data, quantization_params); } + case xnn_operator_type_fully_connected_nc_qdu8_f16_qc4w: { + const void* quantization_params = + input_value->quantization.dynamic_params; + assert(kernel_data == NULL); + assert(bias_data == NULL); + assert(quantization_params != NULL); + return xnn_setup_fully_connected_nc_qdu8_f16_qc4w( + opdata->operator_objects[0], input_data, output_data, + quantization_params); + } case xnn_operator_type_fully_connected_nc_qd8_f32_qb4w: { const void* quantization_params = input_value->quantization.dynamic_params; @@ -719,6 +851,16 @@ static enum xnn_status setup_fully_connected_operator( opdata->operator_objects[0], input_data, output_data, quantization_params); } + case xnn_operator_type_fully_connected_nc_qdu8_f32_qb4w: { + const void* quantization_params = + input_value->quantization.dynamic_params; + assert(kernel_data == NULL); + assert(bias_data == NULL); + assert(quantization_params != NULL); + return xnn_setup_fully_connected_nc_qdu8_f32_qb4w( + opdata->operator_objects[0], input_data, output_data, + quantization_params); + } case xnn_operator_type_fully_connected_nc_qd8_f16_qb4w: { const void* quantization_params = input_value->quantization.dynamic_params; @@ -739,6 
+881,16 @@ static enum xnn_status setup_fully_connected_operator( opdata->operator_objects[0], input_data, output_data, quantization_params); } + case xnn_operator_type_fully_connected_nc_qdu8_f16_qc8w: { + const void* quantization_params = + input_value->quantization.dynamic_params; + assert(kernel_data == NULL); + assert(bias_data == NULL); + assert(quantization_params != NULL); + return xnn_setup_fully_connected_nc_qdu8_f16_qc8w( + opdata->operator_objects[0], input_data, output_data, + quantization_params); + } case xnn_operator_type_fully_connected_nc_qd8_f32_qc8w: { const void* quantization_params = input_value->quantization.dynamic_params; @@ -749,12 +901,28 @@ static enum xnn_status setup_fully_connected_operator( opdata->operator_objects[0], input_data, output_data, quantization_params); } + case xnn_operator_type_fully_connected_nc_qdu8_f32_qc8w: { + const void* quantization_params = + input_value->quantization.dynamic_params; + assert(kernel_data == NULL); + assert(bias_data == NULL); + assert(quantization_params != NULL); + return xnn_setup_fully_connected_nc_qdu8_f32_qc8w( + opdata->operator_objects[0], input_data, output_data, + quantization_params); + } case xnn_operator_type_fully_connected_nc_qp8_f32_qc4w: { assert(kernel_data == NULL); assert(bias_data == NULL); return xnn_setup_fully_connected_nc_qp8_f32_qc4w( opdata->operator_objects[0], input_data, output_data); } + case xnn_operator_type_fully_connected_nc_qp8_f32_qc8w: { + assert(kernel_data == NULL); + assert(bias_data == NULL); + return xnn_setup_fully_connected_nc_qp8_f32_qc8w( + opdata->operator_objects[0], input_data, output_data); + } case xnn_operator_type_fully_connected_nc_qp8_f32_qb4w: { assert(kernel_data == NULL); diff --git a/src/subgraph/reshape-helpers.c b/src/subgraph/reshape-helpers.c index 1dda567fe37f..383e547d8d02 100644 --- a/src/subgraph/reshape-helpers.c +++ b/src/subgraph/reshape-helpers.c @@ -31,7 +31,7 @@ enum xnn_status resize_unary_elementwise_output_tensor( const size_t new_size = xnn_tensor_get_size(output); if (new_size > output->size || opdata->workspace_size > old_workspace_size) { output->size = new_size; - if (output->datatype == xnn_datatype_qdint8) { + if (output->datatype == xnn_datatype_qdint8 || output->datatype == xnn_datatype_qduint8) { // reallocation will use this to adjust memory needed for dynamic quant params output->quantization.dynamic_params_size = xnn_tensor_get_dynamic_quant_param_size(output); } diff --git a/src/subgraph/unary.c b/src/subgraph/unary.c index aff9623a810b..d6e2da0e8b39 100644 --- a/src/subgraph/unary.c +++ b/src/subgraph/unary.c @@ -9,11 +9,8 @@ #include #include "xnnpack.h" -#include "xnnpack/common.h" -#include "xnnpack/config.h" #include "xnnpack/internal.h" #include "xnnpack/log.h" -#include "xnnpack/microparams.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator-utils.h" @@ -54,8 +51,14 @@ static enum xnn_status create_convert_operator( node->flags, &opdata->operator_objects[0]); break; + case xnn_datatype_qduint8: + status = xnn_create_convert_nc_f32_qdu8( + node->flags, + &opdata->operator_objects[0]); + break; case xnn_datatype_qpint8: status = xnn_create_convert_nc_f32_qp8(node->flags, + output_value->gemm_config, &opdata->operator_objects[0]); break; default: @@ -69,6 +72,11 @@ static enum xnn_status create_convert_operator( node->flags, &opdata->operator_objects[0]); break; + case xnn_datatype_qduint8: + status = xnn_create_convert_nc_f16_qdu8( + node->flags, + &opdata->operator_objects[0]); + 
break; default: break; } @@ -100,15 +108,15 @@ static enum xnn_status reshape_convert_operator( const size_t old_workspace_size = opdata->workspace_size; enum xnn_status status = xnn_status_invalid_state; + const uint32_t output_id = opdata->outputs[0]; + assert(output_id < num_values); + const struct xnn_value* output_value = values + output_id; + // Channel stride depends on number of non batch dims. + size_t num_nonbatch_dims = output_value->quantization.num_nonbatch_dims; + size_t dq_batch_size = xnn_shape_multiply_batch_dims(&input_value->shape, num_nonbatch_dims); + size_t dq_channel_stride = xnn_shape_multiply_trailing_dims(&input_value->shape, num_input_dims - num_nonbatch_dims); switch (opdata->operator_objects[0]->type) { case xnn_operator_type_convert_nc_f16_qd8: { - // Channel stride depends on number of non batch dims. - const uint32_t output_id = opdata->outputs[0]; - assert(output_id < num_values); - const struct xnn_value* output_value = values + output_id; - const size_t num_nonbatch_dims = output_value->quantization.num_nonbatch_dims; - const size_t dq_batch_size = xnn_shape_multiply_batch_dims(&input_value->shape, num_nonbatch_dims); - const size_t dq_channel_stride = xnn_shape_multiply_trailing_dims(&input_value->shape, num_input_dims - num_nonbatch_dims); status = xnn_reshape_convert_nc_f16_qd8( opdata->operator_objects[0], dq_batch_size, @@ -117,13 +125,6 @@ static enum xnn_status reshape_convert_operator( break; } case xnn_operator_type_convert_nc_f32_qd8: { - // Channel stride depends on number of non batch dims. - const uint32_t output_id = opdata->outputs[0]; - assert(output_id < num_values); - const struct xnn_value* output_value = values + output_id; - const size_t num_nonbatch_dims = output_value->quantization.num_nonbatch_dims; - const size_t dq_batch_size = xnn_shape_multiply_batch_dims(&input_value->shape, num_nonbatch_dims); - const size_t dq_channel_stride = xnn_shape_multiply_trailing_dims(&input_value->shape, num_input_dims - num_nonbatch_dims); status = xnn_reshape_convert_nc_f32_qd8( opdata->operator_objects[0], dq_batch_size, @@ -131,16 +132,33 @@ static enum xnn_status reshape_convert_operator( threadpool); break; } + case xnn_operator_type_convert_nc_f32_qdu8: { + status = xnn_reshape_convert_nc_f32_qdu8( + opdata->operator_objects[0], + dq_batch_size, + /*channels=*/dq_channel_stride, /*input_stride=*/dq_channel_stride, /*output_stride=*/dq_channel_stride, + threadpool); + break; + } + case xnn_operator_type_convert_nc_f16_qdu8: { + status = xnn_reshape_convert_nc_f16_qdu8( + opdata->operator_objects[0], + dq_batch_size, + /*channels=*/dq_channel_stride, /*input_stride=*/dq_channel_stride, /*output_stride=*/dq_channel_stride, + threadpool); + break; + } case xnn_operator_type_convert_nc_f32_qp8: { - const size_t num_nonbatch_dims = 1; - const size_t dq_batch_size = - xnn_shape_multiply_batch_dims(&input_value->shape, num_nonbatch_dims); - const size_t dq_channel_stride = xnn_shape_multiply_trailing_dims( - &input_value->shape, num_input_dims - num_nonbatch_dims); + size_t num_groups = xnn_shape_multiply_batch_dims(&input_value->shape, 2); + size_t batch_size = input_value->shape.dim[num_input_dims - 2]; + const size_t channels = input_value->shape.dim[num_input_dims - 1]; + if (output_value->squash_groups) { + batch_size *= num_groups; + num_groups = 1; + } status = xnn_reshape_convert_nc_f32_qp8( - opdata->operator_objects[0], dq_batch_size, - /*channels=*/dq_channel_stride, /*input_stride=*/dq_channel_stride, - threadpool); + 
opdata->operator_objects[0], num_groups, batch_size, channels, + /*input_stride=*/channels, threadpool); break; } default: @@ -200,6 +218,26 @@ static enum xnn_status setup_convert_operator( output_data, quantization_params); } + case xnn_operator_type_convert_nc_f16_qdu8: + { + void* quantization_params = output_value->quantization.dynamic_params; + assert(quantization_params != NULL); + return xnn_setup_convert_nc_f16_qdu8( + opdata->operator_objects[0], + input_data, + output_data, + quantization_params); + } + case xnn_operator_type_convert_nc_f32_qdu8: + { + void* quantization_params = output_value->quantization.dynamic_params; + assert(quantization_params != NULL); + return xnn_setup_convert_nc_f32_qdu8( + opdata->operator_objects[0], + input_data, + output_data, + quantization_params); + } case xnn_operator_type_convert_nc_f32_qp8: return xnn_setup_convert_nc_f32_qp8(opdata->operator_objects[0], input_data, output_data); @@ -359,38 +397,11 @@ enum xnn_status xnn_define_unary( if (type == xnn_unary_convert) { // Some convert types are not elementwise ops, handle them now. if (output_value->datatype == xnn_datatype_qdint8 || - // TODO(b/340399245) - Uncomment once we have full support for `qpint8`. - // output_value->datatype == xnn_datatype_qpint8 || - false) { - // Coerce the input from `xnn_datatype_qdint8` to `xnn_datatype_qpint8` if we - // know that we're converting for a GEMM and `qp8_f32_*` kernels are - // available. - // TODO(b/340399245) - Remove xnn_init_qp8_f32_qc4w_gemm_config check once we - // have full qp8 support. - bool pack_activation_for_qc4w = ( - (flags & XNN_FLAG_MAYBE_PACK_FOR_GEMM) && - xnn_init_qp8_f32_qc4w_gemm_config() != NULL - ); - bool pack_activation_for_qb4w = ( - (flags & XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM) && - xnn_init_qp8_f32_qb4w_gemm_config() != NULL - ); - if ((pack_activation_for_qb4w || pack_activation_for_qc4w) && - input_value->datatype == xnn_datatype_fp32 && - output_value->datatype == xnn_datatype_qdint8) { - xnn_log_debug("Coercing type of output ID #%" PRIu32 - " of %s operator from `%s` to `%s`.", - output_id, xnn_node_type_to_string(xnn_node_type_convert), - xnn_datatype_to_string(output_value->datatype), - xnn_datatype_to_string(xnn_datatype_qpint8)); - subgraph->values[output_id].datatype = xnn_datatype_qpint8; - } - + output_value->datatype == xnn_datatype_qduint8) { struct xnn_node* node = xnn_subgraph_new_node(subgraph); if (node == NULL) { return xnn_status_out_of_memory; } - xnn_init_convert_node(node, input_id, output_id, flags); return xnn_status_success; } diff --git a/src/tensor.c b/src/tensor.c index ef7bc83dd46d..96a15e84cee2 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -15,8 +15,6 @@ #include "xnnpack.h" #include "xnnpack/allocation-type.h" #include "xnnpack/common.h" -#include "xnnpack/config-types.h" -#include "xnnpack/config.h" #include "xnnpack/datatype.h" #include "xnnpack/log.h" #include "xnnpack/math.h" @@ -614,9 +612,16 @@ size_t xnn_tensor_get_size(const struct xnn_value* value) // Special handling for packed quantized types. 
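
Editor's note (not part of the patch): the reshape of the f32-to-qp8 convert above now derives an explicit (num_groups, batch_size, channels) geometry from the input shape instead of flattening everything into one batch, and output_value->squash_groups folds the groups back into the batch when a single contiguous packed block is wanted; the qpint8 sizing change in src/tensor.c that follows uses the same geometry to size the buffer as num_groups * xnn_x8_packq_f32qp8_gemm_packed_size(gemm_config, m, k). A minimal sketch of that shape bookkeeping, assuming xnn_shape_multiply_batch_dims(shape, 2) multiplies every dimension except the trailing two (the helper name shape_to_qp8_geometry and the plain-array shape are illustrative only):

#include <stddef.h>

/* Illustrative only: mirrors the num_groups/batch_size/channels derivation
   passed to xnn_reshape_convert_nc_f32_qp8 in the hunk above. */
static void shape_to_qp8_geometry(const size_t* dim, size_t num_dims,
                                  int squash_groups, size_t* num_groups,
                                  size_t* batch_size, size_t* channels) {
  size_t groups = 1;
  for (size_t i = 0; i + 2 < num_dims; i++) {
    groups *= dim[i];              /* every dimension except the last two */
  }
  size_t m = dim[num_dims - 2];    /* rows packed per group */
  *channels = dim[num_dims - 1];   /* reduction dimension / input stride */
  if (squash_groups) {
    m *= groups;                   /* one group with a taller batch */
    groups = 1;
  }
  *num_groups = groups;
  *batch_size = m;
}

/* Example: a [2, 3, 5, 7] input gives num_groups = 6, batch_size = 5,
   channels = 7; with squash_groups set it becomes 1 group of batch 30. */
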
if (value->datatype == xnn_datatype_qpint8) { - const size_t m = xnn_shape_multiply_batch_dims(&value->shape, 1); + assert(value->gemm_config != NULL); + size_t num_groups = xnn_shape_multiply_batch_dims(&value->shape, 2); + size_t m = value->shape.dim[value->shape.num_dims - 2]; const size_t k = value->shape.dim[value->shape.num_dims - 1]; - return xnn_x8_packq_f32qp8_gemm_packed_size(m, k); + if (value->squash_groups) { + m *= num_groups; + num_groups = 1; + } + return num_groups * + xnn_x8_packq_f32qp8_gemm_packed_size(value->gemm_config, m, k); } uint64_t size_bits = xnn_datatype_size_bits(value->datatype); @@ -633,7 +638,8 @@ size_t xnn_tensor_get_size(const struct xnn_value* value) size_t xnn_tensor_get_dynamic_quant_param_size(const struct xnn_value* value) { switch (value->datatype) { - case xnn_datatype_qdint8: { + case xnn_datatype_qdint8: + case xnn_datatype_qduint8: { const size_t batch_dims_size = xnn_shape_multiply_batch_dims( &value->shape, value->quantization.num_nonbatch_dims); return batch_dims_size * sizeof(struct xnn_quantization_params); diff --git a/src/x24-transposec/x24-transposec-4x4-ssse3.c b/src/x24-transposec/x24-transposec-4x4-ssse3.c index 42565ff40331..ab8a13098d0d 100644 --- a/src/x24-transposec/x24-transposec-4x4-ssse3.c +++ b/src/x24-transposec/x24-transposec-4x4-ssse3.c @@ -20,12 +20,12 @@ void xnn_x24_transposec_ukernel__4x4_ssse3( size_t block_width, size_t block_height) XNN_OOB_READS { - static const uint8_t pos0[16] = {0, 4, 8, 2, 6, 10, 1, 5, 9, 3, 7, 11, -1, -1, -1, -1}; - static const uint8_t pos1[16] = {4, 8, 12, 6, 10, 14, 5, 9, 13, 7, 11, 15, -1, -1, -1, -1}; - static const uint8_t pos2[16] = {12, -1, -1, 14, -1, -1, 13, -1, -1, 15, -1, -1, -1, -1, -1, -1}; - static const uint8_t pos3[16] = {-1, 0, 4, -1, 2, 6, -1, 1, 5, -1, 3, 7, -1, -1, -1, -1}; - static const uint8_t pos4[16] = {8, 12, -1, 10, 14, -1, 9, 13, -1, 11, 15, -1, -1, -1, -1, -1}; - static const uint8_t pos5[16] = {-1, -1, 0, -1, -1, 2, -1, -1, 1, -1, -1, 3, -1, -1, -1, -1}; + XNN_ALIGN(16) static const uint8_t pos0[16] = {0, 4, 8, 2, 6, 10, 1, 5, 9, 3, 7, 11, -1, -1, -1, -1}; + XNN_ALIGN(16) static const uint8_t pos1[16] = {4, 8, 12, 6, 10, 14, 5, 9, 13, 7, 11, 15, -1, -1, -1, -1}; + XNN_ALIGN(16) static const uint8_t pos2[16] = {12, -1, -1, 14, -1, -1, 13, -1, -1, 15, -1, -1, -1, -1, -1, -1}; + XNN_ALIGN(16) static const uint8_t pos3[16] = {-1, 0, 4, -1, 2, 6, -1, 1, 5, -1, 3, 7, -1, -1, -1, -1}; + XNN_ALIGN(16) static const uint8_t pos4[16] = {8, 12, -1, 10, 14, -1, 9, 13, -1, 11, 15, -1, -1, -1, -1, -1}; + XNN_ALIGN(16) static const uint8_t pos5[16] = {-1, -1, 0, -1, -1, 2, -1, -1, 1, -1, -1, 3, -1, -1, -1, -1}; assert(output_stride >= block_height * 3); assert(input_stride >= block_width * 3); diff --git a/src/x32-pack-lh/x32-packlh-neonsme2.c b/src/x32-pack-lh/x32-packlh-neonsme2.c index 999894e798c8..4f0efb4cee80 100644 --- a/src/x32-pack-lh/x32-packlh-neonsme2.c +++ b/src/x32-pack-lh/x32-packlh-neonsme2.c @@ -30,8 +30,12 @@ void xnn_x32_pack_lh_ukernel__neonsme2(size_t m, size_t k, size_t mr, size_t lhs_stride, void* XNN_RESTRICT lhs_packed) { #if XNN_ENABLE_KLEIDIAI - kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, m_idx_start, lhs, - lhs_stride, lhs_packed); + if (m == 1) { + memcpy(lhs_packed, lhs, sizeof(float) * k); + } else { + kai_run_lhs_pack_f32p2vlx1_f32_sme(m, k, mr, kr, sr, m_idx_start, lhs, + lhs_stride, lhs_packed); + } #else assert("Not compiled with XNN_ENABLE_KLEIDIAI" && 0); #endif // XNN_ENABLE_KLEIDIAI diff --git 
a/src/x32-packw/gen/x32-packw-gio-hvx-u2.c b/src/x32-packw/gen/x32-packw-gio-hvx-u2.c new file mode 100644 index 000000000000..7d2378b98fbd --- /dev/null +++ b/src/x32-packw/gen/x32-packw-gio-hvx-u2.c @@ -0,0 +1,495 @@ +// Auto-generated file. Do not edit! +// Template: src/x32-packw/gio-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/simd/s32-hvx.h" + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" + +static XNN_INLINE xnn_simd_s32_t +xnn_load_tail_no_oob_s32(const int32_t* input, size_t num_elements) { + assert(num_elements <= xnn_simd_size_s32); + int32_t buf[32]; + for (size_t i = 0; i < num_elements; ++i) { + buf[i] = input[i]; + } + return xnn_loadu_s32((const int32_t*) &buf[0]); +} + + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x32__hvx_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 32); // This kernel is for NR=32 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 32 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 32; n -= 32) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + xnn_storeu_s32(packed_w + 0, vb0); + b += 32; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 32; + + // KC main loop 2x32 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 32, v0_1); + w += k_stride * 2; + packed_w += 64; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 32; + } + w = w - kc * k_stride + 32; // Advance to next column of 32 int32_t + } + + // NC remainder (1..31) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 31); + + // Prepare count for valid 32-bit elements (depends on n). 
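
Editor's note (not part of the patch): in the NC remainder paths of these generated GIO packing kernels, each vector-sized slice of the tail gets a valid-lane count, and xnn_load_tail_no_oob_s32 copies only those elements into a local buffer before issuing a full-width load, so neither the weights nor the bias is read past its end. The ternary chains that compute the counts are a clamp of n - offset into [0, lanes]; a small sketch under that reading (the helper name tail_count is illustrative only):

#include <stddef.h>

/* Illustrative equivalent of the generated expressions such as
     (int) (n - 32) < 0 ? 0 : ((int) (n - 32) > 32 ? 32 : n - 32)
   i.e. the number of valid int32 lanes in the vector that starts at
   lane offset `offset` when only `n` columns remain. */
static size_t tail_count(size_t n, size_t offset, size_t lanes) {
  if (n <= offset) {
    return 0;                       /* this vector lies past the remainder */
  }
  const size_t rem = n - offset;
  return rem > lanes ? lanes : rem; /* clamp to the SIMD width */
}

/* Example with 32-lane HVX vectors and n = 37 remaining columns:
   tail_count(37, 0, 32) == 32 and tail_count(37, 32, 32) == 5, matching
   vcount0 and vcount1 in the x64 kernel further down. */
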
+ const size_t vcount0 = n; + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + xnn_storeu_s32(packed_w + 0, vb0); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 32; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 32; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x64__hvx_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 64); // This kernel is for NR=64 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 64 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 64; n -= 64) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 32); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 32, vb1); + b += 64; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 32, vzero); + } + packed_w += 64; + + // KC main loop 2x64 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 32 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 32 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 32, v1_0); + xnn_storeu_s32(packed_w + 64, v0_1); + xnn_storeu_s32(packed_w + 96, v1_1); + w += k_stride * 2; + packed_w += 128; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 32); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 32, v1); + w += k_stride; + packed_w += 64; + } + w = w - kc * k_stride + 64; // Advance to next column of 64 int32_t + } + + // NC remainder (1..63) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 63); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 32 ? 32 : n - 0); + const size_t vcount1 = (int) (n - 32) < 0 ? 0 : ((int) (n - 32) > 32 ? 
32 : n - 32); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 32, vcount1); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 32, vb1); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 32, vzero); + } + packed_w += 64; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 32, vcount1); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 32, v1); + w += k_stride; + packed_w += 64; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x96__hvx_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 96); // This kernel is for NR=96 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 96 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 96; n -= 96) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 32); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 64); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 32, vb1); + xnn_storeu_s32(packed_w + 64, vb2); + b += 96; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 32, vzero); + xnn_storeu_s32(packed_w + 64, vzero); + } + packed_w += 96; + + // KC main loop 2x96 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 32 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 64 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 32 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 64 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 32, v1_0); + xnn_storeu_s32(packed_w + 64, v2_0); + xnn_storeu_s32(packed_w + 96, v0_1); + xnn_storeu_s32(packed_w + 128, v1_1); + xnn_storeu_s32(packed_w + 160, v2_1); + w += k_stride * 2; + packed_w += 192; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 32); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 64); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 32, v1); 
+ xnn_storeu_s32(packed_w + 64, v2); + w += k_stride; + packed_w += 96; + } + w = w - kc * k_stride + 96; // Advance to next column of 96 int32_t + } + + // NC remainder (1..95) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 95); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 32 ? 32 : n - 0); + const size_t vcount1 = (int) (n - 32) < 0 ? 0 : ((int) (n - 32) > 32 ? 32 : n - 32); + const size_t vcount2 = (int) (n - 64) < 0 ? 0 : ((int) (n - 64) > 32 ? 32 : n - 64); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 32, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 64, vcount2); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 32, vb1); + xnn_storeu_s32(packed_w + 64, vb2); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 32, vzero); + xnn_storeu_s32(packed_w + 64, vzero); + } + packed_w += 96; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 32, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 64, vcount2); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 32, v1); + xnn_storeu_s32(packed_w + 64, v2); + w += k_stride; + packed_w += 96; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x128__hvx_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. 
aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 128); // This kernel is for NR=128 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 128 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 128; n -= 128) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 32); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 64); + const xnn_simd_s32_t vb3 = xnn_loadu_s32(b + 96); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 32, vb1); + xnn_storeu_s32(packed_w + 64, vb2); + xnn_storeu_s32(packed_w + 96, vb3); + b += 128; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 32, vzero); + xnn_storeu_s32(packed_w + 64, vzero); + xnn_storeu_s32(packed_w + 96, vzero); + } + packed_w += 128; + + // KC main loop 2x128 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 32 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 64 + 0 * k_stride); + const xnn_simd_s32_t v3_0 = xnn_loadu_s32(w + 96 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 32 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 64 + 1 * k_stride); + const xnn_simd_s32_t v3_1 = xnn_loadu_s32(w + 96 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 32, v1_0); + xnn_storeu_s32(packed_w + 64, v2_0); + xnn_storeu_s32(packed_w + 96, v3_0); + xnn_storeu_s32(packed_w + 128, v0_1); + xnn_storeu_s32(packed_w + 160, v1_1); + xnn_storeu_s32(packed_w + 192, v2_1); + xnn_storeu_s32(packed_w + 224, v3_1); + w += k_stride * 2; + packed_w += 256; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 32); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 64); + const xnn_simd_s32_t v3 = xnn_loadu_s32(w + 96); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 32, v1); + xnn_storeu_s32(packed_w + 64, v2); + xnn_storeu_s32(packed_w + 96, v3); + w += k_stride; + packed_w += 128; + } + w = w - kc * k_stride + 128; // Advance to next column of 128 int32_t + } + + // NC remainder (1..127) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 127); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 32 ? 32 : n - 0); + const size_t vcount1 = (int) (n - 32) < 0 ? 0 : ((int) (n - 32) > 32 ? 32 : n - 32); + const size_t vcount2 = (int) (n - 64) < 0 ? 0 : ((int) (n - 64) > 32 ? 32 : n - 64); + const size_t vcount3 = (int) (n - 96) < 0 ? 0 : ((int) (n - 96) > 32 ? 
32 : n - 96); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 32, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 64, vcount2); + const xnn_simd_s32_t vb3 = xnn_load_tail_no_oob_s32(b + 96, vcount3); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 32, vb1); + xnn_storeu_s32(packed_w + 64, vb2); + xnn_storeu_s32(packed_w + 96, vb3); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 32, vzero); + xnn_storeu_s32(packed_w + 64, vzero); + xnn_storeu_s32(packed_w + 96, vzero); + } + packed_w += 128; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 32, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 64, vcount2); + const xnn_simd_s32_t v3 = xnn_load_tail_no_oob_s32(w + 96, vcount3); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 32, v1); + xnn_storeu_s32(packed_w + 64, v2); + xnn_storeu_s32(packed_w + 96, v3); + w += k_stride; + packed_w += 128; + } + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x32-packw/gen/x32-packw-gio-neon-u2.c b/src/x32-packw/gen/x32-packw-gio-neon-u2.c new file mode 100644 index 000000000000..5b5cf42b770f --- /dev/null +++ b/src/x32-packw/gen/x32-packw-gio-neon-u2.c @@ -0,0 +1,495 @@ +// Auto-generated file. Do not edit! +// Template: src/x32-packw/gio-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "xnnpack/simd/s32-neon.h" + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" + +static XNN_INLINE xnn_simd_s32_t +xnn_load_tail_no_oob_s32(const int32_t* input, size_t num_elements) { + assert(num_elements <= xnn_simd_size_s32); + int32_t buf[4]; + for (size_t i = 0; i < num_elements; ++i) { + buf[i] = input[i]; + } + return xnn_loadu_s32((const int32_t*) &buf[0]); +} + + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x4__neon_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. 
aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 4); // This kernel is for NR=4 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 4 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 4; n -= 4) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + xnn_storeu_s32(packed_w + 0, vb0); + b += 4; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 4; + + // KC main loop 2x4 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v0_1); + w += k_stride * 2; + packed_w += 8; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 4; + } + w = w - kc * k_stride + 4; // Advance to next column of 4 int32_t + } + + // NC remainder (1..3) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 3); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = n; + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + xnn_storeu_s32(packed_w + 0, vb0); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 4; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 4; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x8__neon_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. 
aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); // This kernel is for NR=8 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 8 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 8; n -= 8) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + b += 8; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + } + packed_w += 8; + + // KC main loop 2x8 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v0_1); + xnn_storeu_s32(packed_w + 12, v1_1); + w += k_stride * 2; + packed_w += 16; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + w += k_stride; + packed_w += 8; + } + w = w - kc * k_stride + 8; // Advance to next column of 8 int32_t + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 7); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + } + packed_w += 8; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + w += k_stride; + packed_w += 8; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x12__neon_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. 
aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 12); // This kernel is for NR=12 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 12 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 12; n -= 12) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 8); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + b += 12; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + } + packed_w += 12; + + // KC main loop 2x12 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 8 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 8 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v2_0); + xnn_storeu_s32(packed_w + 12, v0_1); + xnn_storeu_s32(packed_w + 16, v1_1); + xnn_storeu_s32(packed_w + 20, v2_1); + w += k_stride * 2; + packed_w += 24; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 8); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + w += k_stride; + packed_w += 12; + } + w = w - kc * k_stride + 12; // Advance to next column of 12 int32_t + } + + // NC remainder (1..11) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 11); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + const size_t vcount2 = (int) (n - 8) < 0 ? 0 : ((int) (n - 8) > 4 ? 
4 : n - 8); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 8, vcount2); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + } + packed_w += 12; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 8, vcount2); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + w += k_stride; + packed_w += 12; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x16__neon_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); // This kernel is for NR=16 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 16 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 16; n -= 16) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 8); + const xnn_simd_s32_t vb3 = xnn_loadu_s32(b + 12); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + xnn_storeu_s32(packed_w + 12, vb3); + b += 16; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + xnn_storeu_s32(packed_w + 12, vzero); + } + packed_w += 16; + + // KC main loop 2x16 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 8 + 0 * k_stride); + const xnn_simd_s32_t v3_0 = xnn_loadu_s32(w + 12 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 8 + 1 * k_stride); + const xnn_simd_s32_t v3_1 = xnn_loadu_s32(w + 12 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + 
xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v2_0); + xnn_storeu_s32(packed_w + 12, v3_0); + xnn_storeu_s32(packed_w + 16, v0_1); + xnn_storeu_s32(packed_w + 20, v1_1); + xnn_storeu_s32(packed_w + 24, v2_1); + xnn_storeu_s32(packed_w + 28, v3_1); + w += k_stride * 2; + packed_w += 32; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 8); + const xnn_simd_s32_t v3 = xnn_loadu_s32(w + 12); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + xnn_storeu_s32(packed_w + 12, v3); + w += k_stride; + packed_w += 16; + } + w = w - kc * k_stride + 16; // Advance to next column of 16 int32_t + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 15); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + const size_t vcount2 = (int) (n - 8) < 0 ? 0 : ((int) (n - 8) > 4 ? 4 : n - 8); + const size_t vcount3 = (int) (n - 12) < 0 ? 0 : ((int) (n - 12) > 4 ? 4 : n - 12); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 8, vcount2); + const xnn_simd_s32_t vb3 = xnn_load_tail_no_oob_s32(b + 12, vcount3); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + xnn_storeu_s32(packed_w + 12, vb3); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + xnn_storeu_s32(packed_w + 12, vzero); + } + packed_w += 16; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 8, vcount2); + const xnn_simd_s32_t v3 = xnn_load_tail_no_oob_s32(w + 12, vcount3); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + xnn_storeu_s32(packed_w + 12, v3); + w += k_stride; + packed_w += 16; + } + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x32-packw/gen/x32-packw-gio-sse41-u2.c b/src/x32-packw/gen/x32-packw-gio-sse41-u2.c new file mode 100644 index 000000000000..5f185fd6d28d --- /dev/null +++ b/src/x32-packw/gen/x32-packw-gio-sse41-u2.c @@ -0,0 +1,495 @@ +// Auto-generated file. Do not edit! +// Template: src/x32-packw/gio-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
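
Editor's note (not part of the patch): the HVX kernels above, the NEON kernels just preceding this point, and the SSE4.1 kernels that follow are all instances of the same src/x32-packw/gio-simd.c.in template, differing only in the xnnpack/simd/s32-*.h header they include and in how many vectors make up one NR-wide row. The layout they produce is, per tile of nr output channels: one row of nr bias values (zeros when bias is NULL), then kc rows of nr weights read from the pre-transposed (GIO) source with stride k_stride. A scalar reference sketch of that layout for a single group (illustrative only; it zero-fills the padding lanes of a partial tile, whereas the tail path of the generated kernels leaves those lanes unspecified):

#include <stddef.h>
#include <stdint.h>

/* Illustrative scalar packer producing the same layout as the generated
   xnn_x32_packw_gemm_gio_ukernel_* kernels, assuming g == 1 and
   extra_bytes == 0. */
static void pack_gio_reference(size_t nc, size_t kc, size_t nr,
                               size_t k_stride, const uint32_t* weights,
                               const uint32_t* bias, uint32_t* packed_w) {
  for (size_t n0 = 0; n0 < nc; n0 += nr) {
    const size_t n = nc - n0 < nr ? nc - n0 : nr;   /* columns in this tile */
    for (size_t i = 0; i < nr; i++) {               /* bias row, padded */
      *packed_w++ = (bias != NULL && i < n) ? bias[n0 + i] : 0;
    }
    for (size_t k = 0; k < kc; k++) {               /* kc weight rows */
      for (size_t i = 0; i < nr; i++) {
        *packed_w++ = i < n ? weights[k * k_stride + n0 + i] : 0;
      }
    }
  }
}
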
+ + +#include +#include +#include + +#include "xnnpack/simd/s32-sse41.h" + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" + +static XNN_INLINE xnn_simd_s32_t +xnn_load_tail_no_oob_s32(const int32_t* input, size_t num_elements) { + assert(num_elements <= xnn_simd_size_s32); + int32_t buf[4]; + for (size_t i = 0; i < num_elements; ++i) { + buf[i] = input[i]; + } + return xnn_loadu_s32((const int32_t*) &buf[0]); +} + + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x4__sse41_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 4); // This kernel is for NR=4 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 4 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 4; n -= 4) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + xnn_storeu_s32(packed_w + 0, vb0); + b += 4; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 4; + + // KC main loop 2x4 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v0_1); + w += k_stride * 2; + packed_w += 8; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 4; + } + w = w - kc * k_stride + 4; // Advance to next column of 4 int32_t + } + + // NC remainder (1..3) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 3); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = n; + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + xnn_storeu_s32(packed_w + 0, vb0); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 4; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 4; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x8__sse41_u2( + size_t g, // Batch size (outer loop). 
usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); // This kernel is for NR=8 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 8 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 8; n -= 8) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + b += 8; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + } + packed_w += 8; + + // KC main loop 2x8 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v0_1); + xnn_storeu_s32(packed_w + 12, v1_1); + w += k_stride * 2; + packed_w += 16; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + w += k_stride; + packed_w += 8; + } + w = w - kc * k_stride + 8; // Advance to next column of 8 int32_t + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 7); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + } + packed_w += 8; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + w += k_stride; + packed_w += 8; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x12__sse41_u2( + size_t g, // Batch size (outer loop). 
usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 12); // This kernel is for NR=12 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 12 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 12; n -= 12) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 8); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + b += 12; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + } + packed_w += 12; + + // KC main loop 2x12 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 8 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 8 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v2_0); + xnn_storeu_s32(packed_w + 12, v0_1); + xnn_storeu_s32(packed_w + 16, v1_1); + xnn_storeu_s32(packed_w + 20, v2_1); + w += k_stride * 2; + packed_w += 24; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 8); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + w += k_stride; + packed_w += 12; + } + w = w - kc * k_stride + 12; // Advance to next column of 12 int32_t + } + + // NC remainder (1..11) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 11); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + const size_t vcount2 = (int) (n - 8) < 0 ? 0 : ((int) (n - 8) > 4 ? 
4 : n - 8); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 8, vcount2); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + } + packed_w += 12; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 8, vcount2); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + w += k_stride; + packed_w += 12; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x16__sse41_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); // This kernel is for NR=16 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 16 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 16; n -= 16) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 8); + const xnn_simd_s32_t vb3 = xnn_loadu_s32(b + 12); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + xnn_storeu_s32(packed_w + 12, vb3); + b += 16; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + xnn_storeu_s32(packed_w + 12, vzero); + } + packed_w += 16; + + // KC main loop 2x16 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 8 + 0 * k_stride); + const xnn_simd_s32_t v3_0 = xnn_loadu_s32(w + 12 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 8 + 1 * k_stride); + const xnn_simd_s32_t v3_1 = xnn_loadu_s32(w + 12 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + 
xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v2_0); + xnn_storeu_s32(packed_w + 12, v3_0); + xnn_storeu_s32(packed_w + 16, v0_1); + xnn_storeu_s32(packed_w + 20, v1_1); + xnn_storeu_s32(packed_w + 24, v2_1); + xnn_storeu_s32(packed_w + 28, v3_1); + w += k_stride * 2; + packed_w += 32; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 8); + const xnn_simd_s32_t v3 = xnn_loadu_s32(w + 12); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + xnn_storeu_s32(packed_w + 12, v3); + w += k_stride; + packed_w += 16; + } + w = w - kc * k_stride + 16; // Advance to next column of 16 int32_t + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 15); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + const size_t vcount2 = (int) (n - 8) < 0 ? 0 : ((int) (n - 8) > 4 ? 4 : n - 8); + const size_t vcount3 = (int) (n - 12) < 0 ? 0 : ((int) (n - 12) > 4 ? 4 : n - 12); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 8, vcount2); + const xnn_simd_s32_t vb3 = xnn_load_tail_no_oob_s32(b + 12, vcount3); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + xnn_storeu_s32(packed_w + 12, vb3); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + xnn_storeu_s32(packed_w + 12, vzero); + } + packed_w += 16; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 8, vcount2); + const xnn_simd_s32_t v3 = xnn_load_tail_no_oob_s32(w + 12, vcount3); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + xnn_storeu_s32(packed_w + 12, v3); + w += k_stride; + packed_w += 16; + } + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x32-packw/gen/x32-packw-gio-wasmsimd-u2.c b/src/x32-packw/gen/x32-packw-gio-wasmsimd-u2.c new file mode 100644 index 000000000000..60743747f216 --- /dev/null +++ b/src/x32-packw/gen/x32-packw-gio-wasmsimd-u2.c @@ -0,0 +1,495 @@ +// Auto-generated file. Do not edit! +// Template: src/x32-packw/gio-simd.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
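
Note on the NC-remainder handling used throughout the new SSE4.1 and WASM SIMD GIO packers in this diff: each vcountN is n - N*4 clamped to the range [0, 4], and xnn_load_tail_no_oob_s32 stages only the valid elements through a small stack buffer so that a full 4-lane load never reads past the end of the bias or weight row. A minimal standalone sketch of that logic, using hypothetical helper names (clamp_tail_count, load_tail_no_oob) that are not part of XNNPACK:

  #include <assert.h>
  #include <stddef.h>
  #include <stdint.h>
  #include <string.h>

  // clamp(n - offset, 0, 4): number of valid int32 lanes in the 4-lane group
  // that starts `offset` columns into the remaining `n` columns.
  static size_t clamp_tail_count(size_t n, size_t offset) {
    if (n <= offset) return 0;
    const size_t remaining = n - offset;
    return remaining < 4 ? remaining : 4;
  }

  // Copy only `count` valid elements into a local buffer; the caller then does
  // a full 4-lane load from `buf`. Lanes past `count` are padding whose value
  // is never consumed by the GEMM.
  static void load_tail_no_oob(const int32_t* src, size_t count, int32_t buf[4]) {
    assert(count <= 4);
    memcpy(buf, src, count * sizeof(int32_t));  // never reads past src + count
  }

For NR=16, vcount0 through vcount3 correspond to clamp_tail_count(n, 0), clamp_tail_count(n, 4), clamp_tail_count(n, 8), and clamp_tail_count(n, 12).
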
+ + +#include +#include +#include + +#include "xnnpack/simd/s32-wasmsimd.h" + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" + +static XNN_INLINE xnn_simd_s32_t +xnn_load_tail_no_oob_s32(const int32_t* input, size_t num_elements) { + assert(num_elements <= xnn_simd_size_s32); + int32_t buf[4]; + for (size_t i = 0; i < num_elements; ++i) { + buf[i] = input[i]; + } + return xnn_loadu_s32((const int32_t*) &buf[0]); +} + + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x4__wasmsimd_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 4); // This kernel is for NR=4 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 4 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 4; n -= 4) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + xnn_storeu_s32(packed_w + 0, vb0); + b += 4; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 4; + + // KC main loop 2x4 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v0_1); + w += k_stride * 2; + packed_w += 8; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 4; + } + w = w - kc * k_stride + 4; // Advance to next column of 4 int32_t + } + + // NC remainder (1..3) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 3); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = n; + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + xnn_storeu_s32(packed_w + 0, vb0); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + } + packed_w += 4; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + xnn_storeu_s32(packed_w + 0, v0); + w += k_stride; + packed_w += 4; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x8__wasmsimd_u2( + size_t g, // Batch size (outer loop). 
usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); // This kernel is for NR=8 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 8 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 8; n -= 8) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + b += 8; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + } + packed_w += 8; + + // KC main loop 2x8 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v0_1); + xnn_storeu_s32(packed_w + 12, v1_1); + w += k_stride * 2; + packed_w += 16; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + w += k_stride; + packed_w += 8; + } + w = w - kc * k_stride + 8; // Advance to next column of 8 int32_t + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 7); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + } + packed_w += 8; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + w += k_stride; + packed_w += 8; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x12__wasmsimd_u2( + size_t g, // Batch size (outer loop). 
usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 12); // This kernel is for NR=12 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 12 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 12; n -= 12) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 8); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + b += 12; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + } + packed_w += 12; + + // KC main loop 2x12 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 8 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 8 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v2_0); + xnn_storeu_s32(packed_w + 12, v0_1); + xnn_storeu_s32(packed_w + 16, v1_1); + xnn_storeu_s32(packed_w + 20, v2_1); + w += k_stride * 2; + packed_w += 24; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 8); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + w += k_stride; + packed_w += 12; + } + w = w - kc * k_stride + 12; // Advance to next column of 12 int32_t + } + + // NC remainder (1..11) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 11); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + const size_t vcount2 = (int) (n - 8) < 0 ? 0 : ((int) (n - 8) > 4 ? 
4 : n - 8); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 8, vcount2); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + } + packed_w += 12; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 8, vcount2); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + w += k_stride; + packed_w += 12; + } + } + weights += nc * kc; + } while (--g != 0); +} + +// Pack pre-transposed weights (GIO) for use by f32-gemm +void xnn_x32_packw_gemm_gio_ukernel_x16__wasmsimd_u2( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); // This kernel is for NR=16 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of 16 + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= 16; n -= 16) { + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_loadu_s32(b + 0); + const xnn_simd_s32_t vb1 = xnn_loadu_s32(b + 4); + const xnn_simd_s32_t vb2 = xnn_loadu_s32(b + 8); + const xnn_simd_s32_t vb3 = xnn_loadu_s32(b + 12); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + xnn_storeu_s32(packed_w + 12, vb3); + b += 16; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + xnn_storeu_s32(packed_w + 12, vzero); + } + packed_w += 16; + + // KC main loop 2x16 + size_t k = kc; + for (; k >= 2; k -= 2) { + const xnn_simd_s32_t v0_0 = xnn_loadu_s32(w + 0 + 0 * k_stride); + const xnn_simd_s32_t v1_0 = xnn_loadu_s32(w + 4 + 0 * k_stride); + const xnn_simd_s32_t v2_0 = xnn_loadu_s32(w + 8 + 0 * k_stride); + const xnn_simd_s32_t v3_0 = xnn_loadu_s32(w + 12 + 0 * k_stride); + const xnn_simd_s32_t v0_1 = xnn_loadu_s32(w + 0 + 1 * k_stride); + const xnn_simd_s32_t v1_1 = xnn_loadu_s32(w + 4 + 1 * k_stride); + const xnn_simd_s32_t v2_1 = xnn_loadu_s32(w + 8 + 1 * k_stride); + const xnn_simd_s32_t v3_1 = xnn_loadu_s32(w + 12 + 1 * k_stride); + xnn_storeu_s32(packed_w + 0, v0_0); + 
xnn_storeu_s32(packed_w + 4, v1_0); + xnn_storeu_s32(packed_w + 8, v2_0); + xnn_storeu_s32(packed_w + 12, v3_0); + xnn_storeu_s32(packed_w + 16, v0_1); + xnn_storeu_s32(packed_w + 20, v1_1); + xnn_storeu_s32(packed_w + 24, v2_1); + xnn_storeu_s32(packed_w + 28, v3_1); + w += k_stride * 2; + packed_w += 32; + } + + // KC remainder loop + for (; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_loadu_s32(w + 0); + const xnn_simd_s32_t v1 = xnn_loadu_s32(w + 4); + const xnn_simd_s32_t v2 = xnn_loadu_s32(w + 8); + const xnn_simd_s32_t v3 = xnn_loadu_s32(w + 12); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + xnn_storeu_s32(packed_w + 12, v3); + w += k_stride; + packed_w += 16; + } + w = w - kc * k_stride + 16; // Advance to next column of 16 int32_t + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 15); + + // Prepare count for valid 32-bit elements (depends on n). + const size_t vcount0 = (int) (n - 0) < 0 ? 0 : ((int) (n - 0) > 4 ? 4 : n - 0); + const size_t vcount1 = (int) (n - 4) < 0 ? 0 : ((int) (n - 4) > 4 ? 4 : n - 4); + const size_t vcount2 = (int) (n - 8) < 0 ? 0 : ((int) (n - 8) > 4 ? 4 : n - 8); + const size_t vcount3 = (int) (n - 12) < 0 ? 0 : ((int) (n - 12) > 4 ? 4 : n - 12); + + if XNN_LIKELY(b != NULL) { + const xnn_simd_s32_t vb0 = xnn_load_tail_no_oob_s32(b + 0, vcount0); + const xnn_simd_s32_t vb1 = xnn_load_tail_no_oob_s32(b + 4, vcount1); + const xnn_simd_s32_t vb2 = xnn_load_tail_no_oob_s32(b + 8, vcount2); + const xnn_simd_s32_t vb3 = xnn_load_tail_no_oob_s32(b + 12, vcount3); + xnn_storeu_s32(packed_w + 0, vb0); + xnn_storeu_s32(packed_w + 4, vb1); + xnn_storeu_s32(packed_w + 8, vb2); + xnn_storeu_s32(packed_w + 12, vb3); + b += n; + } else { + xnn_storeu_s32(packed_w + 0, vzero); + xnn_storeu_s32(packed_w + 4, vzero); + xnn_storeu_s32(packed_w + 8, vzero); + xnn_storeu_s32(packed_w + 12, vzero); + } + packed_w += 16; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const xnn_simd_s32_t v0 = xnn_load_tail_no_oob_s32(w + 0, vcount0); + const xnn_simd_s32_t v1 = xnn_load_tail_no_oob_s32(w + 4, vcount1); + const xnn_simd_s32_t v2 = xnn_load_tail_no_oob_s32(w + 8, vcount2); + const xnn_simd_s32_t v3 = xnn_load_tail_no_oob_s32(w + 12, vcount3); + xnn_storeu_s32(packed_w + 0, v0); + xnn_storeu_s32(packed_w + 4, v1); + xnn_storeu_s32(packed_w + 8, v2); + xnn_storeu_s32(packed_w + 12, v3); + w += k_stride; + packed_w += 16; + } + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x32-packw/gen/x32-packw-x16-gemm-gio-scalar.c b/src/x32-packw/gen/x32-packw-x16-gemm-gio-scalar.c index ab8c8a2cbacf..7cd8f7dfd12f 100644 --- a/src/x32-packw/gen/x32-packw-x16-gemm-gio-scalar.c +++ b/src/x32-packw/gen/x32-packw-x16-gemm-gio-scalar.c @@ -39,66 +39,106 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__scalar( assert(weights != NULL); assert(packed_weights != NULL); - const float* b = (const float*) bias; - float* packed_w = (float*) packed_weights; + const uint32_t* b = bias; + uint32_t* packed_w = packed_weights; do { // NC main loop multiple of 16 - const float* w = (const float*) weights; + const uint32_t* w = weights; size_t n = nc; for (; n >= 16; n -= 16) { if XNN_LIKELY(b != NULL) { - const uint64_t v0 = ((const uint64_t*)b)[0]; - const uint64_t v1 = ((const uint64_t*)b)[1]; - const uint64_t v2 = ((const uint64_t*)b)[2]; - const uint64_t v3 = ((const uint64_t*)b)[3]; - const uint64_t v4 = ((const uint64_t*)b)[4]; - const uint64_t v5 = ((const uint64_t*)b)[5]; - 
const uint64_t v6 = ((const uint64_t*)b)[6]; - const uint64_t v7 = ((const uint64_t*)b)[7]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; - ((uint64_t*)packed_w)[2] = v2; - ((uint64_t*)packed_w)[3] = v3; - ((uint64_t*)packed_w)[4] = v4; - ((uint64_t*)packed_w)[5] = v5; - ((uint64_t*)packed_w)[6] = v6; - ((uint64_t*)packed_w)[7] = v7; + const uint32_t v0 = b[0]; + const uint32_t v1 = b[1]; + const uint32_t v2 = b[2]; + const uint32_t v3 = b[3]; + const uint32_t v4 = b[4]; + const uint32_t v5 = b[5]; + const uint32_t v6 = b[6]; + const uint32_t v7 = b[7]; + const uint32_t v8 = b[8]; + const uint32_t v9 = b[9]; + const uint32_t v10 = b[10]; + const uint32_t v11 = b[11]; + const uint32_t v12 = b[12]; + const uint32_t v13 = b[13]; + const uint32_t v14 = b[14]; + const uint32_t v15 = b[15]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; + packed_w[4] = v4; + packed_w[5] = v5; + packed_w[6] = v6; + packed_w[7] = v7; + packed_w[8] = v8; + packed_w[9] = v9; + packed_w[10] = v10; + packed_w[11] = v11; + packed_w[12] = v12; + packed_w[13] = v13; + packed_w[14] = v14; + packed_w[15] = v15; b += 16; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; - ((uint64_t*)packed_w)[2] = 0; - ((uint64_t*)packed_w)[3] = 0; - ((uint64_t*)packed_w)[4] = 0; - ((uint64_t*)packed_w)[5] = 0; - ((uint64_t*)packed_w)[6] = 0; - ((uint64_t*)packed_w)[7] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; + packed_w[4] = 0; + packed_w[5] = 0; + packed_w[6] = 0; + packed_w[7] = 0; + packed_w[8] = 0; + packed_w[9] = 0; + packed_w[10] = 0; + packed_w[11] = 0; + packed_w[12] = 0; + packed_w[13] = 0; + packed_w[14] = 0; + packed_w[15] = 0; } packed_w += 16; // KC main loop for (size_t k = kc; k > 0; --k) { - const uint64_t v0 = ((const uint64_t*)w)[0]; - const uint64_t v1 = ((const uint64_t*)w)[1]; - const uint64_t v2 = ((const uint64_t*)w)[2]; - const uint64_t v3 = ((const uint64_t*)w)[3]; - const uint64_t v4 = ((const uint64_t*)w)[4]; - const uint64_t v5 = ((const uint64_t*)w)[5]; - const uint64_t v6 = ((const uint64_t*)w)[6]; - const uint64_t v7 = ((const uint64_t*)w)[7]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; - ((uint64_t*)packed_w)[2] = v2; - ((uint64_t*)packed_w)[3] = v3; - ((uint64_t*)packed_w)[4] = v4; - ((uint64_t*)packed_w)[5] = v5; - ((uint64_t*)packed_w)[6] = v6; - ((uint64_t*)packed_w)[7] = v7; + const uint32_t v0 = w[0]; + const uint32_t v1 = w[1]; + const uint32_t v2 = w[2]; + const uint32_t v3 = w[3]; + const uint32_t v4 = w[4]; + const uint32_t v5 = w[5]; + const uint32_t v6 = w[6]; + const uint32_t v7 = w[7]; + const uint32_t v8 = w[8]; + const uint32_t v9 = w[9]; + const uint32_t v10 = w[10]; + const uint32_t v11 = w[11]; + const uint32_t v12 = w[12]; + const uint32_t v13 = w[13]; + const uint32_t v14 = w[14]; + const uint32_t v15 = w[15]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; + packed_w[4] = v4; + packed_w[5] = v5; + packed_w[6] = v6; + packed_w[7] = v7; + packed_w[8] = v8; + packed_w[9] = v9; + packed_w[10] = v10; + packed_w[11] = v11; + packed_w[12] = v12; + packed_w[13] = v13; + packed_w[14] = v14; + packed_w[15] = v15; w += k_stride; packed_w += 16; } - w = w - kc * k_stride + 16; // Advance to next column of 16 floats + w = w - kc * k_stride + 16; // Advance to next column of 16 uint32_t } // NC remainder (1..15) @@ -112,14 +152,22 @@ void xnn_x32_packw_gemm_gio_ukernel_x16__scalar( } b += n; } else { - ((uint64_t*)packed_w)[0] = 0; - 
((uint64_t*)packed_w)[1] = 0; - ((uint64_t*)packed_w)[2] = 0; - ((uint64_t*)packed_w)[3] = 0; - ((uint64_t*)packed_w)[4] = 0; - ((uint64_t*)packed_w)[5] = 0; - ((uint64_t*)packed_w)[6] = 0; - ((uint64_t*)packed_w)[7] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; + packed_w[4] = 0; + packed_w[5] = 0; + packed_w[6] = 0; + packed_w[7] = 0; + packed_w[8] = 0; + packed_w[9] = 0; + packed_w[10] = 0; + packed_w[11] = 0; + packed_w[12] = 0; + packed_w[13] = 0; + packed_w[14] = 0; + packed_w[15] = 0; } packed_w += 16; diff --git a/src/x32-packw/gen/x32-packw-x32-gemm-gio-scalar.c b/src/x32-packw/gen/x32-packw-x32-gemm-gio-scalar.c index 00284e3305ad..86fea06d7661 100644 --- a/src/x32-packw/gen/x32-packw-x32-gemm-gio-scalar.c +++ b/src/x32-packw/gen/x32-packw-x32-gemm-gio-scalar.c @@ -39,106 +39,186 @@ void xnn_x32_packw_gemm_gio_ukernel_x32__scalar( assert(weights != NULL); assert(packed_weights != NULL); - const float* b = (const float*) bias; - float* packed_w = (float*) packed_weights; + const uint32_t* b = bias; + uint32_t* packed_w = packed_weights; do { // NC main loop multiple of 32 - const float* w = (const float*) weights; + const uint32_t* w = weights; size_t n = nc; for (; n >= 32; n -= 32) { if XNN_LIKELY(b != NULL) { - const uint64_t v0 = ((const uint64_t*)b)[0]; - const uint64_t v1 = ((const uint64_t*)b)[1]; - const uint64_t v2 = ((const uint64_t*)b)[2]; - const uint64_t v3 = ((const uint64_t*)b)[3]; - const uint64_t v4 = ((const uint64_t*)b)[4]; - const uint64_t v5 = ((const uint64_t*)b)[5]; - const uint64_t v6 = ((const uint64_t*)b)[6]; - const uint64_t v7 = ((const uint64_t*)b)[7]; - const uint64_t v8 = ((const uint64_t*)b)[8]; - const uint64_t v9 = ((const uint64_t*)b)[9]; - const uint64_t v10 = ((const uint64_t*)b)[10]; - const uint64_t v11 = ((const uint64_t*)b)[11]; - const uint64_t v12 = ((const uint64_t*)b)[12]; - const uint64_t v13 = ((const uint64_t*)b)[13]; - const uint64_t v14 = ((const uint64_t*)b)[14]; - const uint64_t v15 = ((const uint64_t*)b)[15]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; - ((uint64_t*)packed_w)[2] = v2; - ((uint64_t*)packed_w)[3] = v3; - ((uint64_t*)packed_w)[4] = v4; - ((uint64_t*)packed_w)[5] = v5; - ((uint64_t*)packed_w)[6] = v6; - ((uint64_t*)packed_w)[7] = v7; - ((uint64_t*)packed_w)[8] = v8; - ((uint64_t*)packed_w)[9] = v9; - ((uint64_t*)packed_w)[10] = v10; - ((uint64_t*)packed_w)[11] = v11; - ((uint64_t*)packed_w)[12] = v12; - ((uint64_t*)packed_w)[13] = v13; - ((uint64_t*)packed_w)[14] = v14; - ((uint64_t*)packed_w)[15] = v15; + const uint32_t v0 = b[0]; + const uint32_t v1 = b[1]; + const uint32_t v2 = b[2]; + const uint32_t v3 = b[3]; + const uint32_t v4 = b[4]; + const uint32_t v5 = b[5]; + const uint32_t v6 = b[6]; + const uint32_t v7 = b[7]; + const uint32_t v8 = b[8]; + const uint32_t v9 = b[9]; + const uint32_t v10 = b[10]; + const uint32_t v11 = b[11]; + const uint32_t v12 = b[12]; + const uint32_t v13 = b[13]; + const uint32_t v14 = b[14]; + const uint32_t v15 = b[15]; + const uint32_t v16 = b[16]; + const uint32_t v17 = b[17]; + const uint32_t v18 = b[18]; + const uint32_t v19 = b[19]; + const uint32_t v20 = b[20]; + const uint32_t v21 = b[21]; + const uint32_t v22 = b[22]; + const uint32_t v23 = b[23]; + const uint32_t v24 = b[24]; + const uint32_t v25 = b[25]; + const uint32_t v26 = b[26]; + const uint32_t v27 = b[27]; + const uint32_t v28 = b[28]; + const uint32_t v29 = b[29]; + const uint32_t v30 = b[30]; + const uint32_t v31 = b[31]; + packed_w[0] = v0; + packed_w[1] 
= v1; + packed_w[2] = v2; + packed_w[3] = v3; + packed_w[4] = v4; + packed_w[5] = v5; + packed_w[6] = v6; + packed_w[7] = v7; + packed_w[8] = v8; + packed_w[9] = v9; + packed_w[10] = v10; + packed_w[11] = v11; + packed_w[12] = v12; + packed_w[13] = v13; + packed_w[14] = v14; + packed_w[15] = v15; + packed_w[16] = v16; + packed_w[17] = v17; + packed_w[18] = v18; + packed_w[19] = v19; + packed_w[20] = v20; + packed_w[21] = v21; + packed_w[22] = v22; + packed_w[23] = v23; + packed_w[24] = v24; + packed_w[25] = v25; + packed_w[26] = v26; + packed_w[27] = v27; + packed_w[28] = v28; + packed_w[29] = v29; + packed_w[30] = v30; + packed_w[31] = v31; b += 32; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; - ((uint64_t*)packed_w)[2] = 0; - ((uint64_t*)packed_w)[3] = 0; - ((uint64_t*)packed_w)[4] = 0; - ((uint64_t*)packed_w)[5] = 0; - ((uint64_t*)packed_w)[6] = 0; - ((uint64_t*)packed_w)[7] = 0; - ((uint64_t*)packed_w)[8] = 0; - ((uint64_t*)packed_w)[9] = 0; - ((uint64_t*)packed_w)[10] = 0; - ((uint64_t*)packed_w)[11] = 0; - ((uint64_t*)packed_w)[12] = 0; - ((uint64_t*)packed_w)[13] = 0; - ((uint64_t*)packed_w)[14] = 0; - ((uint64_t*)packed_w)[15] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; + packed_w[4] = 0; + packed_w[5] = 0; + packed_w[6] = 0; + packed_w[7] = 0; + packed_w[8] = 0; + packed_w[9] = 0; + packed_w[10] = 0; + packed_w[11] = 0; + packed_w[12] = 0; + packed_w[13] = 0; + packed_w[14] = 0; + packed_w[15] = 0; + packed_w[16] = 0; + packed_w[17] = 0; + packed_w[18] = 0; + packed_w[19] = 0; + packed_w[20] = 0; + packed_w[21] = 0; + packed_w[22] = 0; + packed_w[23] = 0; + packed_w[24] = 0; + packed_w[25] = 0; + packed_w[26] = 0; + packed_w[27] = 0; + packed_w[28] = 0; + packed_w[29] = 0; + packed_w[30] = 0; + packed_w[31] = 0; } packed_w += 32; // KC main loop for (size_t k = kc; k > 0; --k) { - const uint64_t v0 = ((const uint64_t*)w)[0]; - const uint64_t v1 = ((const uint64_t*)w)[1]; - const uint64_t v2 = ((const uint64_t*)w)[2]; - const uint64_t v3 = ((const uint64_t*)w)[3]; - const uint64_t v4 = ((const uint64_t*)w)[4]; - const uint64_t v5 = ((const uint64_t*)w)[5]; - const uint64_t v6 = ((const uint64_t*)w)[6]; - const uint64_t v7 = ((const uint64_t*)w)[7]; - const uint64_t v8 = ((const uint64_t*)w)[8]; - const uint64_t v9 = ((const uint64_t*)w)[9]; - const uint64_t v10 = ((const uint64_t*)w)[10]; - const uint64_t v11 = ((const uint64_t*)w)[11]; - const uint64_t v12 = ((const uint64_t*)w)[12]; - const uint64_t v13 = ((const uint64_t*)w)[13]; - const uint64_t v14 = ((const uint64_t*)w)[14]; - const uint64_t v15 = ((const uint64_t*)w)[15]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; - ((uint64_t*)packed_w)[2] = v2; - ((uint64_t*)packed_w)[3] = v3; - ((uint64_t*)packed_w)[4] = v4; - ((uint64_t*)packed_w)[5] = v5; - ((uint64_t*)packed_w)[6] = v6; - ((uint64_t*)packed_w)[7] = v7; - ((uint64_t*)packed_w)[8] = v8; - ((uint64_t*)packed_w)[9] = v9; - ((uint64_t*)packed_w)[10] = v10; - ((uint64_t*)packed_w)[11] = v11; - ((uint64_t*)packed_w)[12] = v12; - ((uint64_t*)packed_w)[13] = v13; - ((uint64_t*)packed_w)[14] = v14; - ((uint64_t*)packed_w)[15] = v15; + const uint32_t v0 = w[0]; + const uint32_t v1 = w[1]; + const uint32_t v2 = w[2]; + const uint32_t v3 = w[3]; + const uint32_t v4 = w[4]; + const uint32_t v5 = w[5]; + const uint32_t v6 = w[6]; + const uint32_t v7 = w[7]; + const uint32_t v8 = w[8]; + const uint32_t v9 = w[9]; + const uint32_t v10 = w[10]; + const uint32_t v11 = w[11]; + const uint32_t v12 = 
w[12]; + const uint32_t v13 = w[13]; + const uint32_t v14 = w[14]; + const uint32_t v15 = w[15]; + const uint32_t v16 = w[16]; + const uint32_t v17 = w[17]; + const uint32_t v18 = w[18]; + const uint32_t v19 = w[19]; + const uint32_t v20 = w[20]; + const uint32_t v21 = w[21]; + const uint32_t v22 = w[22]; + const uint32_t v23 = w[23]; + const uint32_t v24 = w[24]; + const uint32_t v25 = w[25]; + const uint32_t v26 = w[26]; + const uint32_t v27 = w[27]; + const uint32_t v28 = w[28]; + const uint32_t v29 = w[29]; + const uint32_t v30 = w[30]; + const uint32_t v31 = w[31]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; + packed_w[4] = v4; + packed_w[5] = v5; + packed_w[6] = v6; + packed_w[7] = v7; + packed_w[8] = v8; + packed_w[9] = v9; + packed_w[10] = v10; + packed_w[11] = v11; + packed_w[12] = v12; + packed_w[13] = v13; + packed_w[14] = v14; + packed_w[15] = v15; + packed_w[16] = v16; + packed_w[17] = v17; + packed_w[18] = v18; + packed_w[19] = v19; + packed_w[20] = v20; + packed_w[21] = v21; + packed_w[22] = v22; + packed_w[23] = v23; + packed_w[24] = v24; + packed_w[25] = v25; + packed_w[26] = v26; + packed_w[27] = v27; + packed_w[28] = v28; + packed_w[29] = v29; + packed_w[30] = v30; + packed_w[31] = v31; w += k_stride; packed_w += 32; } - w = w - kc * k_stride + 32; // Advance to next column of 32 floats + w = w - kc * k_stride + 32; // Advance to next column of 32 uint32_t } // NC remainder (1..31) @@ -152,22 +232,38 @@ void xnn_x32_packw_gemm_gio_ukernel_x32__scalar( } b += n; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; - ((uint64_t*)packed_w)[2] = 0; - ((uint64_t*)packed_w)[3] = 0; - ((uint64_t*)packed_w)[4] = 0; - ((uint64_t*)packed_w)[5] = 0; - ((uint64_t*)packed_w)[6] = 0; - ((uint64_t*)packed_w)[7] = 0; - ((uint64_t*)packed_w)[8] = 0; - ((uint64_t*)packed_w)[9] = 0; - ((uint64_t*)packed_w)[10] = 0; - ((uint64_t*)packed_w)[11] = 0; - ((uint64_t*)packed_w)[12] = 0; - ((uint64_t*)packed_w)[13] = 0; - ((uint64_t*)packed_w)[14] = 0; - ((uint64_t*)packed_w)[15] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; + packed_w[4] = 0; + packed_w[5] = 0; + packed_w[6] = 0; + packed_w[7] = 0; + packed_w[8] = 0; + packed_w[9] = 0; + packed_w[10] = 0; + packed_w[11] = 0; + packed_w[12] = 0; + packed_w[13] = 0; + packed_w[14] = 0; + packed_w[15] = 0; + packed_w[16] = 0; + packed_w[17] = 0; + packed_w[18] = 0; + packed_w[19] = 0; + packed_w[20] = 0; + packed_w[21] = 0; + packed_w[22] = 0; + packed_w[23] = 0; + packed_w[24] = 0; + packed_w[25] = 0; + packed_w[26] = 0; + packed_w[27] = 0; + packed_w[28] = 0; + packed_w[29] = 0; + packed_w[30] = 0; + packed_w[31] = 0; } packed_w += 32; diff --git a/src/x32-packw/gen/x32-packw-x4-gemm-gio-scalar.c b/src/x32-packw/gen/x32-packw-x4-gemm-gio-scalar.c index 2b2253de7ce8..e3e6b299ee0c 100644 --- a/src/x32-packw/gen/x32-packw-x4-gemm-gio-scalar.c +++ b/src/x32-packw/gen/x32-packw-x4-gemm-gio-scalar.c @@ -39,36 +39,46 @@ void xnn_x32_packw_gemm_gio_ukernel_x4__scalar( assert(weights != NULL); assert(packed_weights != NULL); - const float* b = (const float*) bias; - float* packed_w = (float*) packed_weights; + const uint32_t* b = bias; + uint32_t* packed_w = packed_weights; do { // NC main loop multiple of 4 - const float* w = (const float*) weights; + const uint32_t* w = weights; size_t n = nc; for (; n >= 4; n -= 4) { if XNN_LIKELY(b != NULL) { - const uint64_t v0 = ((const uint64_t*)b)[0]; - const uint64_t v1 = ((const uint64_t*)b)[1]; - ((uint64_t*)packed_w)[0] = 
v0; - ((uint64_t*)packed_w)[1] = v1; + const uint32_t v0 = b[0]; + const uint32_t v1 = b[1]; + const uint32_t v2 = b[2]; + const uint32_t v3 = b[3]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; b += 4; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; } packed_w += 4; // KC main loop for (size_t k = kc; k > 0; --k) { - const uint64_t v0 = ((const uint64_t*)w)[0]; - const uint64_t v1 = ((const uint64_t*)w)[1]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; + const uint32_t v0 = w[0]; + const uint32_t v1 = w[1]; + const uint32_t v2 = w[2]; + const uint32_t v3 = w[3]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; w += k_stride; packed_w += 4; } - w = w - kc * k_stride + 4; // Advance to next column of 4 floats + w = w - kc * k_stride + 4; // Advance to next column of 4 uint32_t } // NC remainder (1..3) @@ -82,8 +92,10 @@ void xnn_x32_packw_gemm_gio_ukernel_x4__scalar( } b += n; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; } packed_w += 4; diff --git a/src/x32-packw/gen/x32-packw-x8-gemm-gio-scalar.c b/src/x32-packw/gen/x32-packw-x8-gemm-gio-scalar.c index d7418f09245d..dc14ae91b3b4 100644 --- a/src/x32-packw/gen/x32-packw-x8-gemm-gio-scalar.c +++ b/src/x32-packw/gen/x32-packw-x8-gemm-gio-scalar.c @@ -39,46 +39,66 @@ void xnn_x32_packw_gemm_gio_ukernel_x8__scalar( assert(weights != NULL); assert(packed_weights != NULL); - const float* b = (const float*) bias; - float* packed_w = (float*) packed_weights; + const uint32_t* b = bias; + uint32_t* packed_w = packed_weights; do { // NC main loop multiple of 8 - const float* w = (const float*) weights; + const uint32_t* w = weights; size_t n = nc; for (; n >= 8; n -= 8) { if XNN_LIKELY(b != NULL) { - const uint64_t v0 = ((const uint64_t*)b)[0]; - const uint64_t v1 = ((const uint64_t*)b)[1]; - const uint64_t v2 = ((const uint64_t*)b)[2]; - const uint64_t v3 = ((const uint64_t*)b)[3]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; - ((uint64_t*)packed_w)[2] = v2; - ((uint64_t*)packed_w)[3] = v3; + const uint32_t v0 = b[0]; + const uint32_t v1 = b[1]; + const uint32_t v2 = b[2]; + const uint32_t v3 = b[3]; + const uint32_t v4 = b[4]; + const uint32_t v5 = b[5]; + const uint32_t v6 = b[6]; + const uint32_t v7 = b[7]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; + packed_w[4] = v4; + packed_w[5] = v5; + packed_w[6] = v6; + packed_w[7] = v7; b += 8; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; - ((uint64_t*)packed_w)[2] = 0; - ((uint64_t*)packed_w)[3] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; + packed_w[4] = 0; + packed_w[5] = 0; + packed_w[6] = 0; + packed_w[7] = 0; } packed_w += 8; // KC main loop for (size_t k = kc; k > 0; --k) { - const uint64_t v0 = ((const uint64_t*)w)[0]; - const uint64_t v1 = ((const uint64_t*)w)[1]; - const uint64_t v2 = ((const uint64_t*)w)[2]; - const uint64_t v3 = ((const uint64_t*)w)[3]; - ((uint64_t*)packed_w)[0] = v0; - ((uint64_t*)packed_w)[1] = v1; - ((uint64_t*)packed_w)[2] = v2; - ((uint64_t*)packed_w)[3] = v3; + const uint32_t v0 = w[0]; + const uint32_t v1 = w[1]; + const uint32_t v2 = w[2]; + const uint32_t v3 = w[3]; + const uint32_t v4 = w[4]; + const uint32_t v5 = w[5]; + const uint32_t v6 = w[6]; + const 
uint32_t v7 = w[7]; + packed_w[0] = v0; + packed_w[1] = v1; + packed_w[2] = v2; + packed_w[3] = v3; + packed_w[4] = v4; + packed_w[5] = v5; + packed_w[6] = v6; + packed_w[7] = v7; w += k_stride; packed_w += 8; } - w = w - kc * k_stride + 8; // Advance to next column of 8 floats + w = w - kc * k_stride + 8; // Advance to next column of 8 uint32_t } // NC remainder (1..7) @@ -92,10 +112,14 @@ void xnn_x32_packw_gemm_gio_ukernel_x8__scalar( } b += n; } else { - ((uint64_t*)packed_w)[0] = 0; - ((uint64_t*)packed_w)[1] = 0; - ((uint64_t*)packed_w)[2] = 0; - ((uint64_t*)packed_w)[3] = 0; + packed_w[0] = 0; + packed_w[1] = 0; + packed_w[2] = 0; + packed_w[3] = 0; + packed_w[4] = 0; + packed_w[5] = 0; + packed_w[6] = 0; + packed_w[7] = 0; } packed_w += 8; diff --git a/src/x32-packw/gio-scalar.c.in b/src/x32-packw/gio-scalar.c.in index 64f97fbaa9ed..dbf8a7951c19 100644 --- a/src/x32-packw/gio-scalar.c.in +++ b/src/x32-packw/gio-scalar.c.in @@ -35,36 +35,36 @@ void xnn_x32_packw_gemm_gio_ukernel_x${NR}__scalar( assert(weights != NULL); assert(packed_weights != NULL); - const float* b = (const float*) bias; - float* packed_w = (float*) packed_weights; + const uint32_t* b = bias; + uint32_t* packed_w = packed_weights; do { // NC main loop multiple of ${NR} - const float* w = (const float*) weights; + const uint32_t* w = weights; size_t n = nc; for (; n >= ${NR}; n -= ${NR}) { if XNN_LIKELY(b != NULL) { - $for N in range(0,NR,2): - const uint64_t v${N//2} = ((const uint64_t*)b)[${N//2}]; - $for N in range(0,NR,2): - ((uint64_t*)packed_w)[${N//2}] = v${N//2}; + $for N in range(NR): + const uint32_t v${N} = b[${N}]; + $for N in range(NR): + packed_w[${N}] = v${N}; b += ${NR}; } else { - $for N in range(0,NR,2): - ((uint64_t*)packed_w)[${N//2}] = 0; + $for N in range(NR): + packed_w[${N}] = 0; } packed_w += ${NR}; // KC main loop for (size_t k = kc; k > 0; --k) { - $for N in range(0,NR,2): - const uint64_t v${N//2} = ((const uint64_t*)w)[${N//2}]; - $for N in range(0,NR,2): - ((uint64_t*)packed_w)[${N//2}] = v${N//2}; + $for N in range(NR): + const uint32_t v${N} = w[${N}]; + $for N in range(NR): + packed_w[${N}] = v${N}; w += k_stride; packed_w += ${NR}; } - w = w - kc * k_stride + ${NR}; // Advance to next column of ${NR} floats + w = w - kc * k_stride + ${NR}; // Advance to next column of ${NR} uint32_t } // NC remainder (1..${NR-1}) @@ -78,8 +78,8 @@ void xnn_x32_packw_gemm_gio_ukernel_x${NR}__scalar( } b += n; } else { - $for N in range(0,NR,2): - ((uint64_t*)packed_w)[${N//2}] = 0; + $for N in range(NR): + packed_w[${N}] = 0; } packed_w += ${NR}; diff --git a/src/x32-packw/gio-simd.c.in b/src/x32-packw/gio-simd.c.in new file mode 100644 index 000000000000..ce761ce4abf8 --- /dev/null +++ b/src/x32-packw/gio-simd.c.in @@ -0,0 +1,149 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
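
Note on the layout these GIO packers produce (both the rewritten scalar kernels above and the new gio-simd.c.in template that follows): within each group, every block of nr columns is written as nr bias values (zeros when bias is NULL) followed by kc rows of nr weights gathered from the row-major GIO source with row stride k_stride. Below is a simplified scalar reference of that layout, as a sketch only: it ignores g and extra_bytes, assumes kr == sr == 1 as these kernels require, and zero-fills padding lanes that the real kernels may leave unspecified; pack_gio_reference is a hypothetical name.

  #include <stddef.h>
  #include <stdint.h>

  static void pack_gio_reference(size_t nc, size_t kc, size_t nr, size_t k_stride,
                                 const uint32_t* weights, const uint32_t* bias,
                                 uint32_t* packed) {
    for (size_t n0 = 0; n0 < nc; n0 += nr) {
      const size_t cols = (nc - n0) < nr ? (nc - n0) : nr;  // columns in this block
      // nr bias values, zero past the last real column (or all zero if bias is NULL).
      for (size_t i = 0; i < nr; i++) {
        *packed++ = (bias != NULL && i < cols) ? bias[n0 + i] : 0;
      }
      // kc rows of nr weights, read from the row-major GIO source.
      for (size_t k = 0; k < kc; k++) {
        for (size_t i = 0; i < nr; i++) {
          *packed++ = (i < cols) ? weights[k * k_stride + n0 + i] : 0;
        }
      }
    }
  }
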
+ +$BATCH_TILES = tuple(int(bt) for bt in BATCH_TILES.split(",")) +$SIMD_SIZE = BATCH_TILES[0] + +#include +#include +#include + +#include "xnnpack/simd/s32-${ARCH}.h" + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" +$if PREFETCH: + #include "xnnpack/prefetch.h" + +static XNN_INLINE xnn_simd_s32_t +xnn_load_tail_no_oob_s32(const int32_t* input, size_t num_elements) { + assert(num_elements <= xnn_simd_size_s32); + int32_t buf[${SIMD_SIZE}]; + for (size_t i = 0; i < num_elements; ++i) { + buf[i] = input[i]; + } + return xnn_loadu_s32((const int32_t*) &buf[0]); +} + +$for NR in BATCH_TILES: + $SIMD_TILE = NR // SIMD_SIZE + + // Pack pre-transposed weights (GIO) for use by f32-gemm + void xnn_x32_packw_gemm_gio_ukernel_x${NR}__${ARCH}_u${KBLOCK}${"_prfm" if PREFETCH else ""}( + size_t g, // Batch size (outer loop). usually 1 + size_t nc, // Number of columns and typically large + size_t kc, // Number of rows and typically small + size_t nr, // Matches gemm and is a multiple of vector sizes + size_t kr, // unused - must be 1 + size_t sr, // unused - must be 1 + size_t k_stride, // Elements per row (typically same as nc) + const uint32_t* weights, // Weights to pack. unaligned, unpadded + const uint32_t* bias, // Bias to pack. unaligned, unpadded, can be NULL + const void* scale, // unused + uint32_t* packed_weights, // packed weights output buffer - aligned, padded + size_t extra_bytes, // number of extra bytes between weights. aligned + const void* params) // unused + { + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); // This kernel is for NR=${NR} + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + const xnn_simd_s32_t vzero = xnn_set1_s32(0); + const int32_t* b = (const int32_t*) bias; + int32_t* packed_w = (int32_t*) packed_weights; + do { + // NC main loop multiple of ${NR} + const int32_t* w = (const int32_t*) weights; + size_t n = nc; + + for (; n >= ${NR}; n -= ${NR}) { + if XNN_LIKELY(b != NULL) { + $for N in range(SIMD_TILE): + const xnn_simd_s32_t vb${N} = xnn_loadu_s32(b + ${N*SIMD_SIZE}); + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vb${N}); + b += ${NR}; + } else { + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vzero); + } + packed_w += ${NR}; + + // KC main loop ${KBLOCK}x${NR} + size_t k = kc; + $if KBLOCK > 1: + for (; k >= ${KBLOCK}; k -= ${KBLOCK}) { + $for K in range(KBLOCK): + $for N in range(SIMD_TILE): + const xnn_simd_s32_t v${N}_${K} = xnn_loadu_s32(w + ${N*SIMD_SIZE} + ${K} * k_stride); + $if PREFETCH: + $for K in range(KBLOCK): + $for N in range(SIMD_TILE): + xnn_prefetch_to_l1((const int8_t*) w + ${N*4+960}); + $for K in range(KBLOCK): + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE+K*NR}, v${N}_${K}); + w += k_stride * ${KBLOCK}; + packed_w += ${NR*KBLOCK}; + } + + // KC remainder loop + for (; k > 0; --k) { + $for N in range(SIMD_TILE): + const xnn_simd_s32_t v${N} = xnn_loadu_s32(w + ${N*SIMD_SIZE}); + $if PREFETCH: + $for N in range(SIMD_TILE): + xnn_prefetch_to_l1((const int8_t*) w + ${N*4+960}); + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, v${N}); + w += k_stride; + packed_w += ${NR}; + } + w = w - kc * k_stride + ${NR}; // Advance to next column of ${NR} int32_t + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= ${NR-1}); + + // Prepare count for valid 32-bit elements (depends on 
n). + $for N in range(SIMD_TILE): + $if SIMD_TILE == 1: + const size_t vcount0 = n; + $else: + const size_t vcount${N} = (int) (n - ${N*SIMD_SIZE}) < 0 ? 0 : ((int) (n - ${N*SIMD_SIZE}) > ${SIMD_SIZE} ? ${SIMD_SIZE} : n - ${N*SIMD_SIZE}); + + if XNN_LIKELY(b != NULL) { + $for N in range(SIMD_TILE): + const xnn_simd_s32_t vb${N} = xnn_load_tail_no_oob_s32(b + ${N*SIMD_SIZE}, vcount${N}); + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vb${N}); + b += n; + } else { + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, vzero); + } + packed_w += ${NR}; + + // KC main loop + for (size_t k = kc; k > 0; --k) { + $for N in range(SIMD_TILE): + const xnn_simd_s32_t v${N} = xnn_load_tail_no_oob_s32(w + ${N*SIMD_SIZE}, vcount${N}); + $for N in range(SIMD_TILE): + xnn_storeu_s32(packed_w + ${N*SIMD_SIZE}, v${N}); + w += k_stride; + packed_w += ${NR}; + } + } + weights += nc * kc; + } while (--g != 0); + } diff --git a/src/x32-packw/x32-packw.h b/src/x32-packw/x32-packw.h index b54a22b21237..be1dab451400 100644 --- a/src/x32-packw/x32-packw.h +++ b/src/x32-packw/x32-packw.h @@ -36,6 +36,11 @@ XNN_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_goi_ukernel_x16__neon_ld4lane_ XNN_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_goi_ukernel_x16__neon_ld4lane_u4_prfm, 16, 1, 1, 4, 1) XNN_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_goi_ukernel_x16__neon_ld4lane_u8, 16, 1, 1, 8, 1) XNN_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_goi_ukernel_x16__neon_ld4lane_u8_prfm, 16, 1, 1, 8, 1) + +XNN_GIO_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_gio_ukernel_x4__neon_u2, 4, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_gio_ukernel_x8__neon_u2, 8, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_gio_ukernel_x12__neon_u2, 12, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_arm_neon, xnn_x32_packw_gemm_gio_ukernel_x16__neon_u2, 16, 1, 1, 1, 1) #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #if XNN_ARCH_X86 || XNN_ARCH_X86_64 @@ -78,6 +83,11 @@ XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8, 16 XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8_prfm, 16, 1, 1, 8, 1) XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x32__avx_u8, 32, 1, 1, 8, 1) XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x32__avx_u8_prfm, 32, 1, 1, 8, 1) + +XNN_GIO_UKERNEL(xnn_arch_x86_sse4_1, xnn_x32_packw_gemm_gio_ukernel_x4__sse41_u2, 4, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_x86_sse4_1, xnn_x32_packw_gemm_gio_ukernel_x8__sse41_u2, 8, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_x86_sse4_1, xnn_x32_packw_gemm_gio_ukernel_x12__sse41_u2, 12, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_x86_sse4_1, xnn_x32_packw_gemm_gio_ukernel_x16__sse41_u2, 16, 1, 1, 1, 1) #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 XNN_GIO_UKERNEL(0, xnn_x32_packw_gemm_gio_ukernel_x4__scalar, 4, 1, 1, 1, 1) @@ -101,11 +111,15 @@ XNN_GIO_UKERNEL(xnn_arch_x86_avx512f, xnn_x32_packw_gemm_gio_ukernel_x32__avx512 XNN_GIO_UKERNEL(xnn_arch_x86_avx512f, xnn_x32_packw_gemm_gio_ukernel_x32__avx512f_u8_prfm, 32, 1, 1, 8, 1) #endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86_64 || XNN_ARCH_X86) - #if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD XNN_UKERNEL(0, xnn_x32_packw_gemm_goi_ukernel_x2c4__wasmsimd_u4, 2, 4, 1, 4, 1) XNN_UKERNEL(0, xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4, 8, 1, 1, 4, 1) XNN_UKERNEL(0, xnn_x32_packw_gemm_goi_ukernel_x8s4__wasmsimd_u4, 8, 1, 4, 4, 1) + +XNN_GIO_UKERNEL(0, xnn_x32_packw_gemm_gio_ukernel_x4__wasmsimd_u2, 4, 1, 1, 1, 
1) +XNN_GIO_UKERNEL(0, xnn_x32_packw_gemm_gio_ukernel_x8__wasmsimd_u2, 8, 1, 1, 1, 1) +XNN_GIO_UKERNEL(0, xnn_x32_packw_gemm_gio_ukernel_x12__wasmsimd_u2, 12, 1, 1, 1, 1) +XNN_GIO_UKERNEL(0, xnn_x32_packw_gemm_gio_ukernel_x16__wasmsimd_u2, 16, 1, 1, 1, 1) #endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD XNN_UKERNEL(0, xnn_x32_packw_gemm_goi_ukernel_x2__scalar_float_u4, 2, 1, 1, 4, 1) @@ -134,6 +148,12 @@ XNN_UKERNEL(xnn_arch_riscv_vector, xnn_x32_packw_gemm_goi_ukernel_x8v__rvv_u4, 8 XNN_UKERNEL(xnn_arch_riscv_vector, xnn_x32_packw_gemm_goi_ukernel_x8v__rvv_u8, 8, 1, 1, 8, xnn_init_hardware_config()->vlenb / sizeof(uint32_t)) #endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV() +#if XNN_ENABLE_HVX && (XNN_ARCH_HEXAGON) +XNN_GIO_UKERNEL(xnn_arch_hvx, xnn_x32_packw_gemm_gio_ukernel_x32__hvx_u2, 32, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_hvx, xnn_x32_packw_gemm_gio_ukernel_x64__hvx_u2, 64, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_hvx, xnn_x32_packw_gemm_gio_ukernel_x96__hvx_u2, 96, 1, 1, 1, 1) +XNN_GIO_UKERNEL(xnn_arch_hvx, xnn_x32_packw_gemm_gio_ukernel_x128__hvx_u2, 128, 1, 1, 1, 1) +#endif // XNN_ENABLE_HVX && (XNN_ARCH_HEXAGON) #ifdef XNN_DEFINED_UKERNEL_WITH_PARAMS #undef XNN_DEFINED_UKERNEL_WITH_PARAMS diff --git a/src/x32-zip/x32-zip-x2-neon.c b/src/x32-zip/x32-zip-x2-neon.c deleted file mode 100644 index d56f32fa0464..000000000000 --- a/src/x32-zip/x32-zip-x2-neon.c +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x2_ukernel__neon( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - uint32_t* o = output; - - while (n >= 16) { - uint32x4x2_t vxy; - vxy.val[0] = vld1q_u32(x); x += 4; - vxy.val[1] = vld1q_u32(y); y += 4; - vst2q_u32(o, vxy); o += 8; - n -= 16; - } - if XNN_UNLIKELY(n != 0) { - if (n & 8) { - uint32x2x2_t vxy; - vxy.val[0] = vld1_u32(x); x += 2; - vxy.val[1] = vld1_u32(y); y += 2; - vst2_u32(o, vxy); o += 4; - } - if (n & 4) { - uint32x2_t vxy = vld1_dup_u32(x); - vxy = vld1_lane_u32(y, vxy, 1); - vst1_u32(o, vxy); - } - } -} diff --git a/src/x32-zip/x32-zip-x2-scalar.c b/src/x32-zip/x32-zip-x2-scalar.c deleted file mode 100644 index f6e3c86b1f4f..000000000000 --- a/src/x32-zip/x32-zip-x2-scalar.c +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x2_ukernel__scalar( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - - do { - const uint32_t vx = *x++; - const uint32_t vy = *y++; - output[0] = vx; - output[1] = vy; - output += 2; - - n -= 4; - } while (n != 0); -} diff --git a/src/x32-zip/x32-zip-x2-sse2.c b/src/x32-zip/x32-zip-x2-sse2.c deleted file mode 100644 index 548976c41ff9..000000000000 --- a/src/x32-zip/x32-zip-x2-sse2.c +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x2_ukernel__sse2( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - uint32_t* o = output; - - while (n >= 16) { - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 4; - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 4; - const __m128i vxy_lo = _mm_unpacklo_epi32(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi32(vx, vy); - _mm_storeu_si128((__m128i*) o, vxy_lo); - _mm_storeu_si128((__m128i*) (o + 4), vxy_hi); - o += 8; - n -= 16; - } - if XNN_UNLIKELY(n != 0) { - if (n & 8) { - const __m128i vx = _mm_loadl_epi64((const __m128i*) x); - x += 2; - const __m128i vy = _mm_loadl_epi64((const __m128i*) y); - y += 2; - const __m128i vxy = _mm_unpacklo_epi32(vx, vy); - _mm_storeu_si128((__m128i*) o, vxy); - o += 4; - } - if (n & 4) { - const uint32_t vx = *x; - const uint32_t vy = *y; - o[0] = vx; - o[1] = vy; - } - } -} diff --git a/src/x32-zip/x32-zip-x2-wasmsimd.c b/src/x32-zip/x32-zip-x2-wasmsimd.c deleted file mode 100644 index f2478be361be..000000000000 --- a/src/x32-zip/x32-zip-x2-wasmsimd.c +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2020 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x2_ukernel__wasmsimd( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % sizeof(uint32_t) == 0); - - const float* x = (const float*) input; - const float* y = (const float*) ((uintptr_t) x + n); - float* o = (float*) output; - - while (n >= 4 * sizeof(uint32_t)) { - const v128_t vx = wasm_v128_load(x); - x += 4; - const v128_t vy = wasm_v128_load(y); - y += 4; - const v128_t vxy_lo = wasm_v32x4_shuffle(vx, vy, 0, 4, 1, 5); - const v128_t vxy_hi = wasm_v32x4_shuffle(vx, vy, 2, 6, 3, 7); - wasm_v128_store(o, vxy_lo); - wasm_v128_store(o + 4, vxy_hi); - o += 8; - n -= 4 * sizeof(uint32_t); - } - if XNN_UNLIKELY(n != 0) { - if (n & (2 * sizeof(uint32_t))) { - const double vx = *((const double*) x); - x += 2; - const double vy = *((const double*) y); - y += 2; - const v128_t vxy = wasm_f64x2_make(vx, vy); - wasm_v128_store(o, wasm_v32x4_shuffle(vxy, vxy, 0, 2, 1, 3)); - o += 4; - } - if (n & (1 * sizeof(uint32_t))) { - const float vx = *x; - const float vy = *y; - o[0] = vx; - o[1] = vy; - } - } -} diff --git a/src/x32-zip/x32-zip-x3-neon.c b/src/x32-zip/x32-zip-x3-neon.c deleted file mode 100644 index 8ca3baa1291f..000000000000 --- a/src/x32-zip/x32-zip-x3-neon.c +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x3_ukernel__neon( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n); - uint32_t* o = output; - - while (n >= 16) { - uint32x4x3_t vxyz; - vxyz.val[0] = vld1q_u32(x); x += 4; - vxyz.val[1] = vld1q_u32(y); y += 4; - vxyz.val[2] = vld1q_u32(z); z += 4; - vst3q_u32(o, vxyz); o += 12; - n -= 16; - } - if XNN_UNLIKELY(n != 0) { - if (n & 8) { - uint32x2x3_t vxyz; - vxyz.val[0] = vld1_u32(x); x += 2; - vxyz.val[1] = vld1_u32(y); y += 2; - vxyz.val[2] = vld1_u32(z); z += 2; - vst3_u32(o, vxyz); o += 6; - } - if (n & 4) { - uint32x2_t vxy = vld1_dup_u32(x); - const uint32x2_t vz = vld1_dup_u32(z); - vxy = vld1_lane_u32(y, vxy, 1); - vst1_u32(o, vxy); o += 2; - vst1_lane_u32(o, vz, 0); - } - } -} diff --git a/src/x32-zip/x32-zip-x3-scalar.c b/src/x32-zip/x32-zip-x3-scalar.c deleted file mode 100644 index 9a7cc7a93d99..000000000000 --- a/src/x32-zip/x32-zip-x3-scalar.c +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x3_ukernel__scalar( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n); - uint32_t* o = output; - - do { - const uint32_t vx = *x++; - const uint32_t vy = *y++; - const uint32_t vz = *z++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o += 3; - - n -= 4; - } while (n != 0); -} diff --git a/src/x32-zip/x32-zip-x3-sse2.c b/src/x32-zip/x32-zip-x3-sse2.c deleted file mode 100644 index bef222574731..000000000000 --- a/src/x32-zip/x32-zip-x3-sse2.c +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x3_ukernel__sse2( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const float* x = (const float*) input; - const float* y = (const float*) ((uintptr_t) x + n); - const float* z = (const float*) ((uintptr_t) y + n); - float* o = (float*) output; - - while (n >= 16) { - // vx = ( x3, x2, x1, x0 ) - const __m128 vx = _mm_loadu_ps(x); - x += 4; - // vy = ( y3, y2, y1, y0 ) - const __m128 vy = _mm_loadu_ps(y); - y += 4; - // vz = ( z3, z2, z1, z0 ) - const __m128 vz = _mm_loadu_ps(z); - z += 4; - - // vxy = ( y2, y0, x2, x0 ) - const __m128 vxy = _mm_shuffle_ps(vx, vy, _MM_SHUFFLE(2, 0, 2, 0)); - // vyz = ( z3, z1, y3, y1 ) - const __m128 vyz = _mm_shuffle_ps(vy, vz, _MM_SHUFFLE(3, 1, 3, 1)); - // vzx = ( x3, x1, z2, z0 ) - const __m128 vzx = _mm_shuffle_ps(vz, vx, _MM_SHUFFLE(3, 1, 2, 0)); - - // vxyz0 = ( x1, z0, y0, x0 ) - const __m128 vxyz0 = _mm_shuffle_ps(vxy, vzx, _MM_SHUFFLE(2, 0, 2, 0)); - // vxyz1 = ( y2, x2, z1, y1 ) - const __m128 vxyz1 = _mm_shuffle_ps(vyz, vxy, _MM_SHUFFLE(3, 1, 2, 0)); - // vxyz2 = ( z3, y3, x3, z2 ) - const __m128 vxyz2 = _mm_shuffle_ps(vzx, vyz, _MM_SHUFFLE(3, 1, 3, 1)); - - _mm_storeu_ps(o, vxyz0); - _mm_storeu_ps(o + 4, vxyz1); - _mm_storeu_ps(o + 8, vxyz2); - o += 12; - n -= 16; - } - if XNN_UNLIKELY(n != 0) { - if (n & 8) { - // vx = ( -, -, x1, x0 ) - const __m128 vx = _mm_castpd_ps(_mm_load_sd((const double*) x)); - x += 2; - // vy = ( -, -, y1, y0 ) - const __m128 vy = _mm_castpd_ps(_mm_load_sd((const double*) y)); - y += 2; - // vz = ( -, -, z1, z0 ) - const __m128 vz = _mm_castpd_ps(_mm_load_sd((const double*) z)); - z += 2; - - // vxy = ( y1, x1, y0, x0 ) - const __m128 vxy = _mm_unpacklo_ps(vx, vy); - // vzx = ( x1, z1, x0, z0 ) - const __m128 vzx = _mm_unpacklo_ps(vz, vx); - // vyz = ( z1, y1, z0, y0 ) - const __m128 vyz = _mm_unpacklo_ps(vy, vz); - - _mm_storeu_ps(o, _mm_shuffle_ps(vxy, vzx, _MM_SHUFFLE(3, 0, 1, 0))); - _mm_storeh_pi((__m64*) (o + 4), vyz); - o += 6; - } - if (n & 4) { - const __m128 vx = _mm_load_ss(x); - const __m128 vy = _mm_load_ss(y); - const __m128 vz = _mm_load_ss(z); - _mm_store_ss(o, vx); - _mm_store_ss(o + 1, vy); - _mm_store_ss(o + 2, vz); - } - } -} diff --git a/src/x32-zip/x32-zip-x3-wasmsimd.c b/src/x32-zip/x32-zip-x3-wasmsimd.c deleted file mode 100644 index 3aac4b632c75..000000000000 --- a/src/x32-zip/x32-zip-x3-wasmsimd.c +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x3_ukernel__wasmsimd( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % sizeof(uint32_t) == 0); - - const float* x = (const float*) input; - const float* y = (const float*) ((uintptr_t) x + n); - const float* z = (const float*) ((uintptr_t) y + n); - float* o = (float*) output; - - while (n >= 4 * sizeof(uint32_t)) { - // vx = ( x3, x2, x1, x0 ) - const v128_t vx = wasm_v128_load(x); - x += 4; - // vy = ( y3, y2, y1, y0 ) - const v128_t vy = wasm_v128_load(y); - y += 4; - // vz = ( z3, z2, z1, z0 ) - const v128_t vz = wasm_v128_load(z); - z += 4; - - // vxy = ( y2, y0, x2, x0 ) - const v128_t vxy = wasm_v32x4_shuffle(vx, vy, 0, 2, 4, 6); - // vyz = ( z3, z1, y3, y1 ) - const v128_t vyz = wasm_v32x4_shuffle(vy, vz, 1, 3, 5, 7); - // vzx = ( x3, x1, z2, z0 ) - const v128_t vzx = wasm_v32x4_shuffle(vz, vx, 0, 2, 5, 7); - - // vxyz0 = ( x1, z0, y0, x0 ) - const v128_t vxyz0 = wasm_v32x4_shuffle(vxy, vzx, 0, 2, 4, 6); - // vxyz1 = ( y2, x2, z1, y1 ) - const v128_t vxyz1 = wasm_v32x4_shuffle(vyz, vxy, 0, 2, 5, 7); - // vxyz2 = ( z3, y3, x3, z2 ) - const v128_t vxyz2 = wasm_v32x4_shuffle(vzx, vyz, 1, 3, 5, 7); - - wasm_v128_store(o, vxyz0); - wasm_v128_store(o + 4, vxyz1); - wasm_v128_store(o + 8, vxyz2); - o += 12; - n -= 4 * sizeof(uint32_t); - } - if XNN_UNLIKELY(n != 0) { - do { - const float vx = *x++; - const float vy = *y++; - const float vz = *z++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o += 3; - n -= sizeof(uint32_t); - } while (n != 0); - } -} diff --git a/src/x32-zip/x32-zip-x4-neon.c b/src/x32-zip/x32-zip-x4-neon.c deleted file mode 100644 index ef9f54b5afa5..000000000000 --- a/src/x32-zip/x32-zip-x4-neon.c +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x4_ukernel__neon( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n); - const uint32_t* w = (const uint32_t*) ((uintptr_t) z + n); - uint32_t* o = output; - - while (n >= 16) { - uint32x4x4_t vxyzw; - vxyzw.val[0] = vld1q_u32(x); x += 4; - vxyzw.val[1] = vld1q_u32(y); y += 4; - vxyzw.val[2] = vld1q_u32(z); z += 4; - vxyzw.val[3] = vld1q_u32(w); w += 4; - vst4q_u32(o, vxyzw); o += 16; - n -= 16; - } - if XNN_UNLIKELY(n != 0) { - if (n & 8) { - uint32x2x4_t vxyzw; - vxyzw.val[0] = vld1_u32(x); x += 2; - vxyzw.val[1] = vld1_u32(y); y += 2; - vxyzw.val[2] = vld1_u32(z); z += 2; - vxyzw.val[3] = vld1_u32(w); w += 2; - vst4_u32(o, vxyzw); o += 8; - } - if (n & 4) { - uint32x4_t vxyzw = vld1q_dup_u32(x); - vxyzw = vld1q_lane_u32(y, vxyzw, 1); - vxyzw = vld1q_lane_u32(z, vxyzw, 2); - vxyzw = vld1q_lane_u32(w, vxyzw, 3); - vst1q_u32(o, vxyzw); - } - } -} diff --git a/src/x32-zip/x32-zip-x4-scalar.c b/src/x32-zip/x32-zip-x4-scalar.c deleted file mode 100644 index 73b36443c1b4..000000000000 --- a/src/x32-zip/x32-zip-x4-scalar.c +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x4_ukernel__scalar( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n); - const uint32_t* w = (const uint32_t*) ((uintptr_t) z + n); - uint32_t* o = output; - - do { - const uint32_t vx = *x++; - const uint32_t vy = *y++; - const uint32_t vz = *z++; - const uint32_t vw = *w++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - o += 4; - - n -= 4; - } while (n != 0); -} diff --git a/src/x32-zip/x32-zip-x4-sse2.c b/src/x32-zip/x32-zip-x4-sse2.c deleted file mode 100644 index 82245e7b8d87..000000000000 --- a/src/x32-zip/x32-zip-x4-sse2.c +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x4_ukernel__sse2( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - - const uint32_t* x = input; - const uint32_t* y = (const uint32_t*) ((uintptr_t) x + n); - const uint32_t* z = (const uint32_t*) ((uintptr_t) y + n); - const uint32_t* w = (const uint32_t*) ((uintptr_t) z + n); - uint32_t* o = output; - - while (n >= 16) { - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 4; - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 4; - const __m128i vz = _mm_loadu_si128((const __m128i*) z); - z += 4; - const __m128i vw = _mm_loadu_si128((const __m128i*) w); - w += 4; - - const __m128i vxy_lo = _mm_unpacklo_epi32(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi32(vx, vy); - const __m128i vzw_lo = _mm_unpacklo_epi32(vz, vw); - const __m128i vzw_hi = _mm_unpackhi_epi32(vz, vw); - - const __m128i vxyzw0 = _mm_unpacklo_epi64(vxy_lo, vzw_lo); - const __m128i vxyzw1 = _mm_unpackhi_epi64(vxy_lo, vzw_lo); - const __m128i vxyzw2 = _mm_unpacklo_epi64(vxy_hi, vzw_hi); - const __m128i vxyzw3 = _mm_unpackhi_epi64(vxy_hi, vzw_hi); - - _mm_storeu_si128((__m128i*) o, vxyzw0); - _mm_storeu_si128((__m128i*) (o + 4), vxyzw1); - _mm_storeu_si128((__m128i*) (o + 8), vxyzw2); - _mm_storeu_si128((__m128i*) (o + 12), vxyzw3); - o += 16; - n -= 16; - } - if XNN_UNLIKELY(n != 0) { - if (n & 8) { - const __m128i vx = _mm_loadl_epi64((const __m128i*) x); - x += 2; - const __m128i vy = _mm_loadl_epi64((const __m128i*) y); - y += 2; - const __m128i vz = _mm_loadl_epi64((const __m128i*) z); - z += 2; - const __m128i vw = _mm_loadl_epi64((const __m128i*) w); - w += 2; - - const __m128i vxy = _mm_unpacklo_epi32(vx, vy); - const __m128i vzw = _mm_unpacklo_epi32(vz, vw); - - const __m128i vxyzw_lo = _mm_unpacklo_epi64(vxy, vzw); - const __m128i vxyzw_hi = _mm_unpackhi_epi64(vxy, vzw); - - _mm_storeu_si128((__m128i*) o, vxyzw_lo); - _mm_storeu_si128((__m128i*) (o + 4), vxyzw_hi); - o += 8; - } - if (n & 4) { - const uint32_t vx = *x; - const uint32_t vy = *y; - const uint32_t vz = *z; - const uint32_t vw = *w; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - } - } -} diff --git a/src/x32-zip/x32-zip-x4-wasmsimd.c b/src/x32-zip/x32-zip-x4-wasmsimd.c deleted file mode 100644 index 74dd35994617..000000000000 --- a/src/x32-zip/x32-zip-x4-wasmsimd.c +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2020 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// 
LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_x4_ukernel__wasmsimd( - size_t n, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % sizeof(uint32_t) == 0); - - const float* x = (const float*) input; - const float* y = (const float*) ((uintptr_t) x + n); - const float* z = (const float*) ((uintptr_t) y + n); - const float* w = (const float*) ((uintptr_t) z + n); - float* o = (float*) output; - - while (n >= 4 * sizeof(uint32_t)) { - const v128_t vx = wasm_v128_load(x); - x += 4; - const v128_t vy = wasm_v128_load(y); - y += 4; - const v128_t vz = wasm_v128_load(z); - z += 4; - const v128_t vw = wasm_v128_load(w); - w += 4; - - const v128_t vxy_lo = wasm_v32x4_shuffle(vx, vy, 0, 4, 1, 5); - const v128_t vxy_hi = wasm_v32x4_shuffle(vx, vy, 2, 6, 3, 7); - const v128_t vzw_lo = wasm_v32x4_shuffle(vz, vw, 0, 4, 1, 5); - const v128_t vzw_hi = wasm_v32x4_shuffle(vz, vw, 2, 6, 3, 7); - - const v128_t vxyzw0 = wasm_v32x4_shuffle(vxy_lo, vzw_lo, 0, 1, 4, 5); - const v128_t vxyzw1 = wasm_v32x4_shuffle(vxy_lo, vzw_lo, 2, 3, 6, 7); - const v128_t vxyzw2 = wasm_v32x4_shuffle(vxy_hi, vzw_hi, 0, 1, 4, 5); - const v128_t vxyzw3 = wasm_v32x4_shuffle(vxy_hi, vzw_hi, 2, 3, 6, 7); - - wasm_v128_store(o, vxyzw0); - wasm_v128_store(o + 4, vxyzw1); - wasm_v128_store(o + 8, vxyzw2); - wasm_v128_store(o + 12, vxyzw3); - o += 16; - n -= 4 * sizeof(uint32_t); - } - if XNN_UNLIKELY(n != 0) { - if (n & (2 * sizeof(uint32_t))) { - const double vx = *((const double*) x); - x += 2; - const double vy = *((const double*) y); - y += 2; - const double vz = *((const double*) z); - z += 2; - const double vw = *((const double*) w); - w += 2; - - const v128_t vxy = wasm_f64x2_make(vx, vy); - const v128_t vzw = wasm_f64x2_make(vz, vw); - - const v128_t vxyzw_lo = wasm_v32x4_shuffle(vxy, vzw, 0, 2, 4, 6); - const v128_t vxyzw_hi = wasm_v32x4_shuffle(vxy, vzw, 1, 3, 5, 7); - - wasm_v128_store(o, vxyzw_lo); - wasm_v128_store(o + 4, vxyzw_hi); - o += 8; - } - if (n & (1 * sizeof(uint32_t))) { - const float vx = *x; - const float vy = *y; - const float vz = *z; - const float vw = *w; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - } - } -} diff --git a/src/x32-zip/x32-zip-xm-neon.c b/src/x32-zip/x32-zip-xm-neon.c deleted file mode 100644 index 13c56e6ced7b..000000000000 --- a/src/x32-zip/x32-zip-xm-neon.c +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_xm_ukernel__neon( - size_t n, - size_t m, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - assert(m >= 4); - - const uint32_t* w = input; - const size_t group_increment = m * 4; - const size_t input_increment = n * 3; - const size_t output_increment = 16 - m * n; - const uint32_t* last_input = (const uint32_t*) ((uintptr_t) input + n * (m - 1)); - uint32_t* last_output = (uint32_t*) ((uintptr_t) output + (m * 4 - 16)); - - for (size_t i = 0; i < m; i += 4) { - w = (const uint32_t*) ((uintptr_t) w + input_increment); - if (w >= last_input) { - w = last_input; - } - const uint32_t* z = (const uint32_t*) ((uintptr_t) w - n); - const uint32_t* y = (const uint32_t*) ((uintptr_t) z - n); - const uint32_t* x = (const uint32_t*) ((uintptr_t) y - n); - - size_t k = n; - while (k >= 16) { - const uint32x4_t vx = vld1q_u32(x); x += 4; - const uint32x4_t vy = vld1q_u32(y); y += 4; - const uint32x4_t vz = vld1q_u32(z); z += 4; - const uint32x4_t vw = vld1q_u32(w); w += 4; - - const uint32x4x2_t vxy = vzipq_u32(vx, vy); - const uint32x4x2_t vzw = vzipq_u32(vz, vw); - - vst1_u32(output, vget_low_u32(vxy.val[0])); - vst1_u32(output + 2, vget_low_u32(vzw.val[0])); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - vst1_u32(output, vget_high_u32(vxy.val[0])); - vst1_u32(output + 2, vget_high_u32(vzw.val[0])); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - vst1_u32(output, vget_low_u32(vxy.val[1])); - vst1_u32(output + 2, vget_low_u32(vzw.val[1])); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - vst1_u32(output, vget_high_u32(vxy.val[1])); - vst1_u32(output + 2, vget_high_u32(vzw.val[1])); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - k -= 16; - } - if XNN_UNLIKELY(k != 0) { - if (k & 8) { - const uint32x2_t vx = vld1_u32(x); x += 2; - const uint32x2_t vy = vld1_u32(y); y += 2; - const uint32x2_t vz = vld1_u32(z); z += 2; - const uint32x2_t vw = vld1_u32(w); w += 2; - - const uint32x2x2_t vxy = vzip_u32(vx, vy); - const uint32x2x2_t vzw = vzip_u32(vz, vw); - - vst1_u32(output, vxy.val[0]); - vst1_u32(output + 2, vzw.val[0]); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - vst1_u32(output, vxy.val[1]); - vst1_u32(output + 2, vzw.val[1]); - output = (uint32_t*) ((uintptr_t) output + group_increment); - } - if (k & 4) { - const uint32x2_t vx = vld1_dup_u32(x); - const uint32x2_t vz = vld1_dup_u32(z); - const uint32x2_t vxy = vld1_lane_u32(y, vx, 1); - const uint32x2_t vzw = vld1_lane_u32(w, vz, 1); w += 1; - - vst1_u32(output, vxy); - vst1_u32(output + 2, vzw); - output = (uint32_t*) ((uintptr_t) output + group_increment); - } - } - output = (uint32_t*) ((uintptr_t) output + output_increment); - if (output > last_output) { - output = last_output; - } - } -} diff --git a/src/x32-zip/x32-zip-xm-scalar.c b/src/x32-zip/x32-zip-xm-scalar.c deleted file mode 100644 index 5d29999fa50f..000000000000 --- a/src/x32-zip/x32-zip-xm-scalar.c +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_xm_ukernel__scalar( - size_t n, - size_t m, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - assert(m >= 4); - - size_t k = n; - do { - size_t l = m; - const uint32_t* input_column = input++; - do { - *output++ = *input_column; - input_column = (uint32_t*) ((uintptr_t) input_column + n); - } while (--l != 0); - k -= 4; - } while (k != 0); -} diff --git a/src/x32-zip/x32-zip-xm-sse2.c b/src/x32-zip/x32-zip-xm-sse2.c deleted file mode 100644 index e5734c6732bd..000000000000 --- a/src/x32-zip/x32-zip-xm-sse2.c +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_xm_ukernel__sse2( - size_t n, - size_t m, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % 4 == 0); - assert(m >= 4); - - const uint32_t* w = input; - const size_t group_increment = m * 4; - const size_t input_increment = n * 3; - const size_t output_increment = 16 - m * n; - const uint32_t* last_input = (const uint32_t*) ((uintptr_t) input + n * (m - 1)); - uint32_t* last_output = (uint32_t*) ((uintptr_t) output + (m * 4 - 16)); - - for (size_t i = 0; i < m; i += 4) { - w = (const uint32_t*) ((uintptr_t) w + input_increment); - if (w >= last_input) { - w = last_input; - } - const uint32_t* z = (const uint32_t*) ((uintptr_t) w - n); - const uint32_t* y = (const uint32_t*) ((uintptr_t) z - n); - const uint32_t* x = (const uint32_t*) ((uintptr_t) y - n); - - size_t k = n; - while (k >= 16) { - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 4; - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 4; - const __m128i vz = _mm_loadu_si128((const __m128i*) z); - z += 4; - const __m128i vw = _mm_loadu_si128((const __m128i*) w); - w += 4; - - const __m128i vxy_lo = _mm_unpacklo_epi32(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi32(vx, vy); - const __m128i vzw_lo = _mm_unpacklo_epi32(vz, vw); - const __m128i vzw_hi = _mm_unpackhi_epi32(vz, vw); - - const __m128i vxyzw0 = _mm_unpacklo_epi64(vxy_lo, vzw_lo); - const __m128i vxyzw1 = _mm_unpackhi_epi64(vxy_lo, vzw_lo); - const __m128i vxyzw2 = _mm_unpacklo_epi64(vxy_hi, vzw_hi); - const __m128i vxyzw3 = _mm_unpackhi_epi64(vxy_hi, vzw_hi); - - _mm_storeu_si128((__m128i*) output, vxyzw0); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - _mm_storeu_si128((__m128i*) output, vxyzw1); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - _mm_storeu_si128((__m128i*) output, vxyzw2); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - _mm_storeu_si128((__m128i*) output, vxyzw3); - output = (uint32_t*) ((uintptr_t) output + group_increment); - - k -= 16; - } - if XNN_UNLIKELY(k != 0) { - if (k & 8) { - const __m128i vx = _mm_loadl_epi64((const __m128i*) x); - x += 2; - const __m128i vy = _mm_loadl_epi64((const __m128i*) y); - y += 2; - const __m128i vz = _mm_loadl_epi64((const __m128i*) z); - z += 2; - const __m128i vw = _mm_loadl_epi64((const __m128i*) w); - w += 2; - - const __m128i vxy = _mm_unpacklo_epi32(vx, vy); - const __m128i vzw = _mm_unpacklo_epi32(vz, vw); - - const __m128i vxyzw_lo = _mm_unpacklo_epi64(vxy, vzw); - const __m128i vxyzw_hi = _mm_unpackhi_epi64(vxy, vzw); - - _mm_storeu_si128((__m128i*) output, vxyzw_lo); - output = (uint32_t*) ((uintptr_t) output + 
group_increment); - - _mm_storeu_si128((__m128i*) output, vxyzw_hi); - output = (uint32_t*) ((uintptr_t) output + group_increment); - } - if (k & 4) { - const uint32_t vx = *x; - const uint32_t vy = *y; - const uint32_t vz = *z; - const uint32_t vw = *w++; - - output[0] = vx; - output[1] = vy; - output[2] = vz; - output[3] = vw; - output = (uint32_t*) ((uintptr_t) output + group_increment); - } - } - output = (uint32_t*) ((uintptr_t) output + output_increment); - if (output > last_output) { - output = last_output; - } - } -} diff --git a/src/x32-zip/x32-zip-xm-wasmsimd.c b/src/x32-zip/x32-zip-xm-wasmsimd.c deleted file mode 100644 index 69d86ce317c7..000000000000 --- a/src/x32-zip/x32-zip-xm-wasmsimd.c +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2020 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include - -#include "xnnpack/zip.h" - - -void xnn_x32_zip_xm_ukernel__wasmsimd( - size_t n, - size_t m, - const uint32_t* input, - uint32_t* output) -{ - assert(n != 0); - assert(n % sizeof(uint32_t) == 0); - assert(m >= 4); - - const float* w = (const float*) input; - float* o = (float*) output; - const size_t group_increment = m * 4; - const size_t input_increment = n * 3; - const size_t output_increment = 4 * sizeof(uint32_t) - m * n; - const float* last_input = (const float*) ((uintptr_t) input + n * (m - 1)); - float* last_output = (float*) ((uintptr_t) output + (m * 4 - 4 * sizeof(uint32_t))); - - for (size_t i = 0; i < m; i += 4) { - w = (const float*) ((uintptr_t) w + input_increment); - if (w >= last_input) { - w = last_input; - } - const float* z = (const float*) ((uintptr_t) w - n); - const float* y = (const float*) ((uintptr_t) z - n); - const float* x = (const float*) ((uintptr_t) y - n); - - size_t k = n; - while (k >= 4 * sizeof(uint32_t)) { - const v128_t vx = wasm_v128_load((const v128_t*) x); - x += 4; - const v128_t vy = wasm_v128_load((const v128_t*) y); - y += 4; - const v128_t vz = wasm_v128_load((const v128_t*) z); - z += 4; - const v128_t vw = wasm_v128_load((const v128_t*) w); - w += 4; - - const v128_t vxy_lo = wasm_v32x4_shuffle(vx, vy, 0, 4, 1, 5); - const v128_t vxy_hi = wasm_v32x4_shuffle(vx, vy, 2, 6, 3, 7); - const v128_t vzw_lo = wasm_v32x4_shuffle(vz, vw, 0, 4, 1, 5); - const v128_t vzw_hi = wasm_v32x4_shuffle(vz, vw, 2, 6, 3, 7); - - const v128_t vxyzw0 = wasm_v32x4_shuffle(vxy_lo, vzw_lo, 0, 1, 4, 5); - const v128_t vxyzw1 = wasm_v32x4_shuffle(vxy_lo, vzw_lo, 2, 3, 6, 7); - const v128_t vxyzw2 = wasm_v32x4_shuffle(vxy_hi, vzw_hi, 0, 1, 4, 5); - const v128_t vxyzw3 = wasm_v32x4_shuffle(vxy_hi, vzw_hi, 2, 3, 6, 7); - - wasm_v128_store(o, vxyzw0); - o = (float*) ((uintptr_t) o + group_increment); - - wasm_v128_store(o, vxyzw1); - o = (float*) ((uintptr_t) o + group_increment); - - wasm_v128_store(o, vxyzw2); - o = (float*) ((uintptr_t) o + group_increment); - - wasm_v128_store(o, vxyzw3); - o = (float*) ((uintptr_t) o + group_increment); - - k -= 4 * sizeof(uint32_t); - } - if XNN_UNLIKELY(k != 0) { - if (k & (2 * sizeof(uint32_t))) { - const double vx = *((const double*) x); - x += 2; - const double vy = *((const double*) y); - y += 2; - const double vz = *((const double*) z); - z += 2; - const double vw = *((const double*) w); - w += 2; - - const v128_t vxy = wasm_f64x2_make(vx, vy); - const v128_t vzw = wasm_f64x2_make(vz, vw); - - const v128_t vxyzw_lo = wasm_v32x4_shuffle(vxy, vzw, 0, 2, 4, 6); - const v128_t vxyzw_hi = 
wasm_v32x4_shuffle(vxy, vzw, 1, 3, 5, 7); - - wasm_v128_store(o, vxyzw_lo); - o = (float*) ((uintptr_t) o + group_increment); - - wasm_v128_store(o, vxyzw_hi); - o = (float*) ((uintptr_t) o + group_increment); - } - if (k & (1 * sizeof(uint32_t))) { - const float vx = *x; - const float vy = *y; - const float vz = *z; - const float vw = *w++; - - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - o = (float*) ((uintptr_t) o + group_increment); - } - } - o = (float*) ((uintptr_t) o + output_increment); - if (o > last_output) { - o = last_output; - } - } -} diff --git a/src/x8-packw/c4-avxvnni.c.in b/src/x8-packw/c4-avxvnni.c.in new file mode 100644 index 000000000000..da17c54f7e78 --- /dev/null +++ b/src/x8-packw/c4-avxvnni.c.in @@ -0,0 +1,366 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR == 64 +$assert KR == 4 +$assert DATATYPE in ["QS8", "X8"] +$assert TYPE in ["int8_t"] +$assert IZP in [0, 128] +$UNROLL = 0 + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +$if PREFETCH: + #include "xnnpack/prefetch.h" + +XNN_INLINE static uint32_t safe_load_u32(const void* src, size_t k) { + uint32_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < k; ++i) { + value |= (uint32_t) bytes[i] << (i * 8); + } + return value; +} + +$BTYPE = {"QS8": "int32_t", "QS4": "int32_t", "X8": "uint32_t"}[DATATYPE] +$WTYPE = {"QS8": "int8_t", "QS4": "uint8_t", "X8": "int8_t"}[DATATYPE] +$PACKEDWTYPE = {"QS8": "int8_t", "QS4": "void", "X8": "int8_t"}[DATATYPE] +$SCALETYPE = {"QS8": "void", "QS4": "float", "X8": "void"}[DATATYPE] +$PARAMTYPE = {"QS8": "void", "QS4": "struct xnn_qs8_qc4w_packing_params", "X8": "void"}[DATATYPE] +$if DATATYPE in ["QS8", "QS4"]: + $_MM256_DPBUSD_EPI32 = "mm256_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32" + $ISA = "avx2" if VARIANT == "MADD" else "avxvnni" if AVX == 2 else "avx256vnni" +$else: + $ISA = "avx2" if AVX == 2 else "avx256skx" +$DATATYPE_SPEC = "qs8_to_qu8" if IZP == 128 else {"QS8": "qs8", "QS4": "qs8_qc4w", "X8": "x8"}[DATATYPE] +$if DATATYPE in ["QS4"]: + // Convert a vector from packed nibbles to planar, and accumulate sum + static XNN_INTRINSIC + __m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = ${_MM256_DPBUSD_EPI32}(*vacc, vone, v01); + *vacc = ${_MM256_DPBUSD_EPI32}(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); + } + +void xnn_${DATATYPE_SPEC}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_prfm" if PREFETCH else ""}( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const ${WTYPE}* weights, + const ${BTYPE}* bias, + const ${SCALETYPE}* scale, + ${PACKEDWTYPE}* packed_weights, + size_t extra_bytes, + const ${PARAMTYPE}* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != 
NULL); + assert(packed_weights != NULL); + assert(params != NULL); + $if DATATYPE == "QS4": + assert(kc % 2 == 0); // This kernel does not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes. Measure in bytes + + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + + $if DATATYPE in ["QS8"]: + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP})); + $elif DATATYPE in ["QS4"]: + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + ${IZP}); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + + $if DATATYPE in ["QS8", "QS4"]: + ${BTYPE}* packed_b = (${BTYPE}*) out; + if XNN_LIKELY(b != NULL) { + $for N in range(0, NR, 8): + const __m256i vb${N} = _mm256_loadu_si256((const __m256i*) (b + ${N})); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), vb${N}); + b += ${NR}; + } else { + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256()); + } + out += ${NR} * sizeof(${BTYPE}); + + $if PREFETCH: + $for N in range(0, NR): + $for OFFSET in range(0, 448, 64): + xnn_prefetch_to_l1((const int8_t*) w${N} + ${OFFSET}); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vacc${N} = _mm256_setzero_si256(); + + size_t k = kc; + $if UNROLL: + // KC main loop multiple of ${NR}x${8 * KR} + for (; k >= ${8 * KR}; k -= ${8 * KR}) { + $for N in range(NR): + const __m256i v${N}_01234567 = _mm256_loadu_si256((const __m256i*) w${N}); + + $for N in range(0, NR, 2): + const __m256i v${N}${N+1}_0145 = _mm256_unpacklo_epi32(v${N}_01234567, v${N+1}_01234567); + const __m256i v${N}${N+1}_2367 = _mm256_unpackhi_epi32(v${N}_01234567, v${N+1}_01234567); + + $for N in range(0, NR, 4): + const __m256i v${N}${N+2}_02 = _mm256_unpacklo_epi64(v${N}${N+1}_0145, v${N+2}${N+3}_0145); + const __m256i v${N}${N+2}_13 = _mm256_unpackhi_epi64(v${N}${N+1}_0145, v${N+2}${N+3}_0145); + const __m256i v${N+1}${N+3}_02 = _mm256_unpacklo_epi64(v${N}${N+1}_2367, v${N+2}${N+3}_2367); + const __m256i v${N+1}${N+3}_13 = _mm256_unpackhi_epi64(v${N}${N+1}_2367, v${N+2}${N+3}_2367); + + $for N in range(0, NR // 4): + $for I in range(0, 2): + $C = N*2+I + const __m256i v${C}${C+4}_0 = _mm256_permute2f128_si256(v${N}${N+2}_${I}${I+2}, v${N+4}${N+6}_${I}${I+2}, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v${C}${C+4}_1 = _mm256_permute2f128_si256(v${N}${N+2}_${I}${I+2}, v${N+4}${N+6}_${I}${I+2}, _MM_SHUFFLE(0, 3, 0, 1)); + + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + $for I in range(0, 2): + $for J in range(0, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${J}${J+4}_${I}); + + $for I in range(0, 2): + $for N in range(0, KR): + _mm256_storeu_si256((__m256i *)&out[${(I*KR + N)*8*KR}], v${N}${N+4}_${I}); + + $for N in range(NR): + w${N} += ${8 * KR}; + out += ${8*NR*KR}; + } + + // KC main loop multiple of 
${NR}x${KR} + for (; k >= ${KR}; k -= ${KR}) { + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N})); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+1})), 0x02); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+2})), 0x04); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+3})), 0x08); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+4})), 0x10); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+5})), 0x20); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+6})), 0x40); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+7})), 0x80); + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 8): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + assert(k >= 1 && k <= ${KR-1}); + + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) safe_load_u32(w${N}, k)); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+1}, k)), 0x02); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+2}, k)), 0x04); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+3}, k)), 0x08); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+4}, k)), 0x10); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+5}, k)), 0x20); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+6}, k)), 0x40); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+7}, k)), 0x80); + + $for N in range(NR): + w${N} += k; + + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 8): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + out += ${NR*KR}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_mullo_epi32(vacc${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w${NR-1}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= ${NR-1}); + + $if DATATYPE in ["QS8", "QS4"]: + ${BTYPE}* packed_b = (${BTYPE}*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb 
= 0; nb < n; ++nb) { + ((${BTYPE}*) out)[nb] = b[nb]; + } + b += n; + } else { + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256()); + } + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vacc${N} = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of ${NR}x${KR} + for (; k >= ${KR}; k -= ${KR}) { + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N})); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+1})), 0x02); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+2})), 0x04); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+3})), 0x08); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+4})), 0x10); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+5})), 0x20); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+6})), 0x40); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+7})), 0x80); + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 8): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + assert(k >= 1 && k <= ${KR-1}); + + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) safe_load_u32(w${N}, k)); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+1}, k)), 0x02); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+2}, k)), 0x04); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+3}, k)), 0x08); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+4}, k)), 0x10); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+5}, k)), 0x20); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+6}, k)), 0x40); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+7}, k)), 0x80); + + $for N in range(NR): + w${N} += k; + + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 8): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + out += ${NR*KR}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_mullo_epi32(vacc${N}, 
vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} \ No newline at end of file diff --git a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2-prfm.c b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2-prfm.c index d17f70563dec..7fda009a6815 100644 --- a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2-prfm.c +++ b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,6 +51,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -51,18 +62,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -78,6 +77,19 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( const int8_t* w13 = w12 + kc; const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -246,22 +258,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( xnn_prefetch_to_l1((const int8_t*) w14 + 448); xnn_prefetch_to_l1((const int8_t*) w15 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = 
_mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -364,28 +376,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 
0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -418,24 +424,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -496,42 +488,242 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( if XNN_UNPREDICTABLE(n < 16) { w15 = w14; } - xnn_prefetch_to_l1((const int8_t*) w0); + + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 
256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); - xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); xnn_prefetch_to_l1((const int8_t*) w8 + 64); - xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); xnn_prefetch_to_l1((const int8_t*) w9 + 64); - xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); xnn_prefetch_to_l1((const int8_t*) w10 + 64); - 
xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); xnn_prefetch_to_l1((const int8_t*) w11 + 64); - xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); xnn_prefetch_to_l1((const int8_t*) w12 + 64); - xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); xnn_prefetch_to_l1((const int8_t*) w13 + 64); - xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); xnn_prefetch_to_l1((const int8_t*) w14 + 64); - xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const 
__m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + 
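/* The 16x32 KC main loop above reads 32 bytes of K from each of the 16 rows and, via the
   unpacklo/unpackhi_epi64 + permute2f128 sequence, emits four consecutive 128-byte blocks,
   each holding 8 contiguous K bytes from every row (rows in order 0..15). A minimal scalar
   sketch of that output layout follows; pack_x16c8_block_sketch is a hypothetical helper for
   illustration only, assumes the kernel's existing includes, and is not part of this diff. */
static void pack_x16c8_block_sketch(const int8_t* const w[16], size_t k_offset, int8_t* out) {
  for (size_t row = 0; row < 16; ++row) {
    for (size_t i = 0; i < 8; ++i) {
      out[row * 8 + i] = w[row][k_offset + i];  // 8 K bytes per row, rows laid out back to back
    }
  }
}
/* One iteration of the vector loop is then equivalent to:
     for (size_t block = 0; block < 4; ++block) {
       pack_x16c8_block_sketch(w, k_offset + block * 8, out + block * 128);
     }
   advancing each w pointer by 32 and out by 512, which matches the 16x8 tail loop below. */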
_mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -595,28 +787,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); 
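/* The KC remainder here replaces the old unaligned_load_u64 + vmask pattern: the old code read a
   full 8 bytes from each row and masked off the (8 - k) bytes past the valid data, which is why
   these kernels carried XNN_OOB_READS. safe_load_u64 assembles only the k valid bytes in
   little-endian order, so the out-of-bounds read (and the annotation) go away while the packed
   value stays the same. A hypothetical standalone check of that equivalence, assuming a
   little-endian target such as x86 and standard headers; it is not part of this diff: */
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

static uint64_t safe_load_u64_sketch(const void* address, size_t n) {
  uint64_t value = 0;
  const uint8_t* bytes = (const uint8_t*) address;
  for (size_t i = 0; i < n; ++i) {
    value |= (uint64_t) bytes[i] << (i * 8);  // place byte i in the i-th lowest byte
  }
  return value;
}

static void check_tail_load_equivalence(void) {
  const uint8_t row[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  for (size_t k = 1; k <= 7; ++k) {
    uint64_t full;
    memcpy(&full, row, sizeof(full));                        // what unaligned_load_u64 read
    const uint64_t mask = (~(uint64_t) 0) >> ((8 - k) * 8);  // what the old vmask kept
    assert(safe_load_u64_sketch(row, k) == (full & mask));   // same packed value, no OOB read
  }
}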
+ v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -644,9 +830,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2_prfm( out += 128; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2.c b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2.c index a26b48e1f375..83a3c9fcc0cc 100644 --- a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2.c +++ b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx2.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( size_t g, @@ -30,7 +40,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,6 +50,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -50,18 +61,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -78,6 +77,19 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + size_t k = kc; // KC main loop multiple of 16x32 @@ -117,22 +129,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 
= _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -219,28 +231,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -273,24 +279,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -352,9 +344,113 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( w15 = w14; } + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const 
__m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + 
_mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -402,28 +498,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, 
_mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -451,9 +541,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx2( out += 128; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx-prfm.c b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx-prfm.c index bb3d53bee4a8..4cd808f6f7ad 100644 --- a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx-prfm.c +++ b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,6 +51,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -51,18 +62,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -78,6 +77,19 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( const int8_t* w13 = w12 + kc; const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -246,22 +258,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( 
xnn_prefetch_to_l1((const int8_t*) w14 + 448); xnn_prefetch_to_l1((const int8_t*) w15 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -364,28 +376,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -418,24 +424,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -496,42 +488,242 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( if XNN_UNPREDICTABLE(n < 16) { w15 = w14; } - xnn_prefetch_to_l1((const int8_t*) w0); + + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) 
(out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); - xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); + xnn_prefetch_to_l1((const int8_t*) w8 + 0); xnn_prefetch_to_l1((const int8_t*) w8 + 64); - xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 192); + xnn_prefetch_to_l1((const int8_t*) w8 + 256); + xnn_prefetch_to_l1((const int8_t*) w8 + 320); + xnn_prefetch_to_l1((const int8_t*) w8 + 384); + xnn_prefetch_to_l1((const int8_t*) w9 + 0); 
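/* The block of prefetches above warms the first seven 64-byte cache lines (offsets 0..384) of
   each clamped row pointer before the K loop starts, and the 16x32 loop then keeps prefetching
   w + 448 so the stream stays about 448 bytes (14 iterations of 32 K bytes) ahead of the loads.
   A compact sketch of the same warm-up pattern; prefetch_row_window is a hypothetical helper,
   assuming the kernel's existing includes (xnnpack/prefetch.h), and is not part of this diff: */
static void prefetch_row_window(const int8_t* w) {
  for (size_t offset = 0; offset < 448; offset += 64) {
    xnn_prefetch_to_l1(w + offset);  // one prefetch per cache line: offsets 0, 64, ..., 384
  }
}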
xnn_prefetch_to_l1((const int8_t*) w9 + 64); - xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 192); + xnn_prefetch_to_l1((const int8_t*) w9 + 256); + xnn_prefetch_to_l1((const int8_t*) w9 + 320); + xnn_prefetch_to_l1((const int8_t*) w9 + 384); + xnn_prefetch_to_l1((const int8_t*) w10 + 0); xnn_prefetch_to_l1((const int8_t*) w10 + 64); - xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 192); + xnn_prefetch_to_l1((const int8_t*) w10 + 256); + xnn_prefetch_to_l1((const int8_t*) w10 + 320); + xnn_prefetch_to_l1((const int8_t*) w10 + 384); + xnn_prefetch_to_l1((const int8_t*) w11 + 0); xnn_prefetch_to_l1((const int8_t*) w11 + 64); - xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 192); + xnn_prefetch_to_l1((const int8_t*) w11 + 256); + xnn_prefetch_to_l1((const int8_t*) w11 + 320); + xnn_prefetch_to_l1((const int8_t*) w11 + 384); + xnn_prefetch_to_l1((const int8_t*) w12 + 0); xnn_prefetch_to_l1((const int8_t*) w12 + 64); - xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 192); + xnn_prefetch_to_l1((const int8_t*) w12 + 256); + xnn_prefetch_to_l1((const int8_t*) w12 + 320); + xnn_prefetch_to_l1((const int8_t*) w12 + 384); + xnn_prefetch_to_l1((const int8_t*) w13 + 0); xnn_prefetch_to_l1((const int8_t*) w13 + 64); - xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 192); + xnn_prefetch_to_l1((const int8_t*) w13 + 256); + xnn_prefetch_to_l1((const int8_t*) w13 + 320); + xnn_prefetch_to_l1((const int8_t*) w13 + 384); + xnn_prefetch_to_l1((const int8_t*) w14 + 0); xnn_prefetch_to_l1((const int8_t*) w14 + 64); - xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 192); + xnn_prefetch_to_l1((const int8_t*) w14 + 256); + xnn_prefetch_to_l1((const int8_t*) w14 + 320); + xnn_prefetch_to_l1((const int8_t*) w14 + 384); + xnn_prefetch_to_l1((const int8_t*) w15 + 0); xnn_prefetch_to_l1((const int8_t*) w15 + 64); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 192); + xnn_prefetch_to_l1((const int8_t*) w15 + 256); + xnn_prefetch_to_l1((const int8_t*) w15 + 320); + xnn_prefetch_to_l1((const int8_t*) w15 + 384); - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const 
__m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + xnn_prefetch_to_l1((const int8_t*) w8 + 448); + xnn_prefetch_to_l1((const int8_t*) w9 + 448); + xnn_prefetch_to_l1((const int8_t*) w10 + 448); + xnn_prefetch_to_l1((const int8_t*) w11 + 448); + xnn_prefetch_to_l1((const int8_t*) w12 + 448); + xnn_prefetch_to_l1((const int8_t*) w13 + 448); + xnn_prefetch_to_l1((const int8_t*) w14 + 448); + xnn_prefetch_to_l1((const int8_t*) w15 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i 
*)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -595,28 +787,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) 
safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -644,9 +830,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx_prfm( out += 128; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx.c b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx.c index f8febc539695..4096ac35646c 100644 --- a/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx.c +++ b/src/x8-packw/gen/x8-packw-x16c8-gemm-goi-avx256skx.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( size_t g, @@ -30,7 +40,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,6 +50,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -50,18 +61,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 16; n -= 16) { - if XNN_LIKELY(b != NULL) { - const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); - const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); - _mm256_storeu_si256((__m256i*) (out + 0), vb0); - _mm256_storeu_si256((__m256i*) (out + 32), vb8); - b += 16; - } else { - _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); - _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); - } - out += 16 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; const int8_t* w2 = w1 + kc; const int8_t* w3 = w2 + kc; @@ -78,6 +77,19 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( const int8_t* w14 = w13 + kc; const int8_t* w15 = w14 + kc; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + size_t k = kc; // KC main loop multiple of 16x32 @@ -117,22 +129,22 @@ void 
xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -219,28 +231,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -273,24 +279,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( } // NC remainder (1..15) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 15); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (16 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -352,9 +344,113 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( w15 = w14; } + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + 
_mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + - // KC main loop multiple of 16x8 size_t k = kc; + // KC main loop multiple of 16x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + const __m256i v8_0123 = _mm256_loadu_si256((const __m256i*) w8); + const __m256i v9_0123 = _mm256_loadu_si256((const __m256i*) w9); + const __m256i v10_0123 = _mm256_loadu_si256((const __m256i*) w10); + const __m256i v11_0123 = _mm256_loadu_si256((const __m256i*) w11); + const __m256i v12_0123 = _mm256_loadu_si256((const __m256i*) w12); + const __m256i v13_0123 = _mm256_loadu_si256((const __m256i*) w13); + const __m256i v14_0123 = _mm256_loadu_si256((const __m256i*) w14); + const __m256i v15_0123 = _mm256_loadu_si256((const __m256i*) w15); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + const __m256i v89_02 = _mm256_unpacklo_epi64(v8_0123, v9_0123); + const __m256i v89_13 = _mm256_unpackhi_epi64(v8_0123, v9_0123); + const __m256i v1011_02 = _mm256_unpacklo_epi64(v10_0123, v11_0123); + const __m256i v1011_13 = _mm256_unpackhi_epi64(v10_0123, v11_0123); + const __m256i v1213_02 = _mm256_unpacklo_epi64(v12_0123, v13_0123); + const __m256i v1213_13 = _mm256_unpackhi_epi64(v12_0123, v13_0123); + const __m256i v1415_02 = _mm256_unpacklo_epi64(v14_0123, v15_0123); + const __m256i v1415_13 = _mm256_unpackhi_epi64(v14_0123, v15_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_0 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_1 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v8_2 = _mm256_permute2f128_si256(v89_02, v1011_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v8_3 = _mm256_permute2f128_si256(v89_13, v1011_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_0 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_1 = 
_mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v12_2 = _mm256_permute2f128_si256(v1213_02, v1415_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v12_3 = _mm256_permute2f128_si256(v1213_13, v1415_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v8_0); + _mm256_storeu_si256((__m256i *)&out[96], v12_0); + _mm256_storeu_si256((__m256i *)&out[128], v0_1); + _mm256_storeu_si256((__m256i *)&out[160], v4_1); + _mm256_storeu_si256((__m256i *)&out[192], v8_1); + _mm256_storeu_si256((__m256i *)&out[224], v12_1); + _mm256_storeu_si256((__m256i *)&out[256], v0_2); + _mm256_storeu_si256((__m256i *)&out[288], v4_2); + _mm256_storeu_si256((__m256i *)&out[320], v8_2); + _mm256_storeu_si256((__m256i *)&out[352], v12_2); + _mm256_storeu_si256((__m256i *)&out[384], v0_3); + _mm256_storeu_si256((__m256i *)&out[416], v4_3); + _mm256_storeu_si256((__m256i *)&out[448], v8_3); + _mm256_storeu_si256((__m256i *)&out[480], v12_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + w8 += 32; + w9 += 32; + w10 += 32; + w11 += 32; + w12 += 32; + w13 += 32; + w14 += 32; + w15 += 32; + out += 512; + } + + // KC main loop multiple of 16x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -402,28 +498,22 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); - __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); - v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); - v8 = _mm256_and_si256(v8, vmask); - __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); - v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); - v12 = _mm256_and_si256(v12, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) 
safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) safe_load_u64(w8, k)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w9, k)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w10, k)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) safe_load_u64(w11, k)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) safe_load_u64(w12, k)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w13, k)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w14, k)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) safe_load_u64(w15, k)), 0xC0); w0 += k; w1 += k; @@ -451,9 +541,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x16c8__avx256skx( out += 128; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2-prfm.c b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2-prfm.c index 1eade6b0884b..3507403b0fda 100644 --- a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2-prfm.c +++ b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,6 +51,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -51,6 +62,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); _mm256_storeu_si256((__m256i*) (out + 0), vb0); @@ -60,13 +79,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( } out += 8 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -155,14 +167,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( xnn_prefetch_to_l1((const int8_t*) w6 
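The x8c8 variants that follow produce the same packed layout as the x16c8 kernels above, just with NR=8: one block of 8 bias words, then, for every 8-wide slice of kc, 8 consecutive weight bytes from each of the 8 rows, with the trailing slice zero-padded and the NC remainder handled by clamping the row pointers. A rough scalar reference of that layout (a hypothetical helper for illustration only; it zero-pads the bias block for simplicity rather than reproducing the kernels' exact remainder behaviour):

#include <stddef.h>
#include <stdint.h>
#include <string.h>

static void x8_packw_x8c8_goi_ref(size_t nc, size_t kc, const int8_t* weights,
                                  const uint32_t* bias, int8_t* out,
                                  size_t extra_bytes) {
  const size_t kc8 = (kc + 7) & ~(size_t) 7;  // kc rounded up to the KR=8 block
  for (size_t n = 0; n < nc; n += 8) {
    const size_t nr = nc - n < 8 ? nc - n : 8;
    for (size_t i = 0; i < 8; ++i) {          // NR bias words, zero-padded here
      const uint32_t b = (bias != NULL && i < nr) ? bias[n + i] : 0;
      memcpy(out, &b, sizeof(b));
      out += sizeof(uint32_t);
    }
    for (size_t k = 0; k < kc8; k += 8) {     // 8 K bytes per row, row-major in the slice
      for (size_t i = 0; i < 8; ++i) {
        const int8_t* row = weights + (n + (i < nr ? i : nr - 1)) * kc;  // clamped rows
        for (size_t j = 0; j < 8; ++j) {
          *out++ = k + j < kc ? row[k + j] : 0;
        }
      }
    }
    out = (int8_t*) ((uintptr_t) out + extra_bytes);
  }
}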
+ 448); xnn_prefetch_to_l1((const int8_t*) w7 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -223,18 +235,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -257,24 +265,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { 
- *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -303,26 +297,137 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( if XNN_UNPREDICTABLE(n < 8) { w7 = w6; } - xnn_prefetch_to_l1((const int8_t*) w0); + + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const 
int8_t*) w7 + 384); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -360,18 +465,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 
0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -389,9 +490,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2_prfm( out += 64; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2.c b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2.c index cc9dd2b7d125..140623e8de31 100644 --- a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2.c +++ b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx2.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( size_t g, @@ -30,7 +40,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,6 +50,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -50,6 +61,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); _mm256_storeu_si256((__m256i*) (out + 0), vb0); @@ -59,13 +78,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( } out += 8 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; size_t k = kc; @@ -90,14 +102,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i 
v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -150,18 +162,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -184,24 +192,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const 
int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -231,9 +225,72 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( w7 = w6; } + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -263,18 +320,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = 
_mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -292,9 +345,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx2( out += 64; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx-prfm.c b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx-prfm.c index 5fe29cf7e090..68e6771c3415 100644 --- a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx-prfm.c +++ b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx-prfm.c @@ -18,6 +18,16 @@ #include "xnnpack/unaligned.h" #include "xnnpack/prefetch.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( size_t g, @@ -31,7 +41,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -41,6 +51,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -51,6 +62,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); _mm256_storeu_si256((__m256i*) (out + 0), vb0); @@ -60,13 +79,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( } out += 8 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); xnn_prefetch_to_l1((const int8_t*) w0 + 128); @@ -155,14 +167,14 @@ void 
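All of these kernels build their 8-byte packing step on the same _mm256_set1_epi64x + _mm256_blend_epi32 idiom: broadcast one row's 8-byte group to every 64-bit lane, then blend dword pairs (immediates 0x0C, 0x30, 0xC0) to drop rows 1, 2 and 3 into lanes 1, 2 and 3, so a single 256-bit store emits four rows' worth of packed bytes. A small AVX2-only sketch of that lane placement (compile with -mavx2; the scaffolding is illustrative, not part of the patch):

#include <assert.h>
#include <immintrin.h>
#include <stdint.h>
#include <string.h>

int main(void) {
  const uint8_t w0[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  const uint8_t w1[8] = {10, 11, 12, 13, 14, 15, 16, 17};
  const uint8_t w2[8] = {20, 21, 22, 23, 24, 25, 26, 27};
  const uint8_t w3[8] = {30, 31, 32, 33, 34, 35, 36, 37};
  uint64_t q0, q1, q2, q3;
  memcpy(&q0, w0, 8);
  memcpy(&q1, w1, 8);
  memcpy(&q2, w2, 8);
  memcpy(&q3, w3, 8);

  __m256i v = _mm256_set1_epi64x((int64_t) q0);
  v = _mm256_blend_epi32(v, _mm256_set1_epi64x((int64_t) q1), 0x0C);  // 64-bit lane 1 <- row 1
  v = _mm256_blend_epi32(v, _mm256_set1_epi64x((int64_t) q2), 0x30);  // 64-bit lane 2 <- row 2
  v = _mm256_blend_epi32(v, _mm256_set1_epi64x((int64_t) q3), 0xC0);  // 64-bit lane 3 <- row 3

  uint8_t packed[32];
  _mm256_storeu_si256((__m256i*) packed, v);
  assert(memcmp(packed +  0, w0, 8) == 0);  // four rows packed back to back
  assert(memcmp(packed +  8, w1, 8) == 0);
  assert(memcmp(packed + 16, w2, 8) == 0);
  assert(memcmp(packed + 24, w3, 8) == 0);
  return 0;
}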
xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 448); xnn_prefetch_to_l1((const int8_t*) w7 + 448); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -223,18 +235,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -257,24 +265,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - 
*((uint32_t*) out) = *b++; - out += sizeof(uint32_t); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -303,26 +297,137 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( if XNN_UNPREDICTABLE(n < 8) { w7 = w6; } - xnn_prefetch_to_l1((const int8_t*) w0); + + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + xnn_prefetch_to_l1((const int8_t*) w0 + 0); xnn_prefetch_to_l1((const int8_t*) w0 + 64); - xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w0 + 256); + xnn_prefetch_to_l1((const int8_t*) w0 + 320); + xnn_prefetch_to_l1((const int8_t*) w0 + 384); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); xnn_prefetch_to_l1((const int8_t*) w1 + 64); - xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 256); + xnn_prefetch_to_l1((const int8_t*) w1 + 320); + xnn_prefetch_to_l1((const int8_t*) w1 + 384); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); xnn_prefetch_to_l1((const int8_t*) w2 + 64); - xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 256); + xnn_prefetch_to_l1((const int8_t*) w2 + 320); + xnn_prefetch_to_l1((const int8_t*) w2 + 384); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); xnn_prefetch_to_l1((const int8_t*) w3 + 64); - xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 256); + xnn_prefetch_to_l1((const int8_t*) w3 + 320); + xnn_prefetch_to_l1((const int8_t*) w3 + 384); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); xnn_prefetch_to_l1((const int8_t*) w4 + 64); - xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 256); + xnn_prefetch_to_l1((const int8_t*) w4 + 320); + xnn_prefetch_to_l1((const int8_t*) w4 + 384); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); xnn_prefetch_to_l1((const int8_t*) w5 + 64); - xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 256); + xnn_prefetch_to_l1((const int8_t*) w5 + 320); + xnn_prefetch_to_l1((const int8_t*) w5 + 384); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); xnn_prefetch_to_l1((const int8_t*) w6 + 64); - xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 256); + xnn_prefetch_to_l1((const int8_t*) w6 + 320); + xnn_prefetch_to_l1((const int8_t*) w6 + 384); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + 
xnn_prefetch_to_l1((const int8_t*) w7 + 256); + xnn_prefetch_to_l1((const int8_t*) w7 + 320); + xnn_prefetch_to_l1((const int8_t*) w7 + 384); - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -360,18 +465,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -389,9 +490,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx_prfm( out += 64; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx.c b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx.c index 50fcde9adedb..76d9d6d29c8a 100644 --- a/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx.c +++ b/src/x8-packw/gen/x8-packw-x8c8-gemm-goi-avx256skx.c @@ -17,6 +17,16 @@ #include "xnnpack/packw.h" #include "xnnpack/unaligned.h" +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} + void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( size_t g, @@ -30,7 +40,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( const void* scale, int8_t* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -40,6 +50,7 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); int8_t* out = (int8_t*) packed_weights; const uint32_t* b = (const uint32_t*) bias; @@ -50,6 +61,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + if XNN_LIKELY(b != NULL) { const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); _mm256_storeu_si256((__m256i*) (out + 0), vb0); @@ -59,13 +78,6 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( } out += 8 * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - const int8_t* w2 = w1 + kc; - const int8_t* w3 = w2 + kc; - const int8_t* w4 = w3 + kc; - const int8_t* w5 = w4 + kc; - const int8_t* w6 = w5 + kc; - const int8_t* w7 = w6 + kc; size_t k = kc; @@ -90,14 +102,14 @@ void 
xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); - const __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); - const __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); - const __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); _mm256_storeu_si256((__m256i *)&out[0], v0_0); @@ -150,18 +162,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -184,24 +192,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( } // NC remainder (1..7) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= 7); - - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((uint32_t*) out) = *b++; - out += sizeof(uint32_t); 
- } while (--nb != 0); - } else { - size_t nb = n; - do { - *((uint32_t*) out) = 0; - out += sizeof(uint32_t); - } while (--nb != 0); - } - out += (8 - n) * sizeof(uint32_t); - + // Clamp weight pointers for NC remainder const int8_t* w1 = w0 + kc; if XNN_UNPREDICTABLE(n < 2) { w1 = w0; @@ -231,9 +225,72 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( w7 = w6; } + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((uint32_t*) out)[nb] = b[nb]; + } + b += n; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(uint32_t); + + - // KC main loop multiple of 8x8 size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123); + const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123); + const __m256i v23_02 = _mm256_unpacklo_epi64(v2_0123, v3_0123); + const __m256i v23_13 = _mm256_unpackhi_epi64(v2_0123, v3_0123); + const __m256i v45_02 = _mm256_unpacklo_epi64(v4_0123, v5_0123); + const __m256i v45_13 = _mm256_unpackhi_epi64(v4_0123, v5_0123); + const __m256i v67_02 = _mm256_unpacklo_epi64(v6_0123, v7_0123); + const __m256i v67_13 = _mm256_unpackhi_epi64(v6_0123, v7_0123); + + + __m256i v0_0 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_1 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v0_2 = _mm256_permute2f128_si256(v01_02, v23_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v0_3 = _mm256_permute2f128_si256(v01_13, v23_13, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_0 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_1 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + _mm256_storeu_si256((__m256i *)&out[0], v0_0); + _mm256_storeu_si256((__m256i *)&out[32], v4_0); + _mm256_storeu_si256((__m256i *)&out[64], v0_1); + _mm256_storeu_si256((__m256i *)&out[96], v4_1); + _mm256_storeu_si256((__m256i *)&out[128], v0_2); + _mm256_storeu_si256((__m256i *)&out[160], v4_2); + _mm256_storeu_si256((__m256i *)&out[192], v0_3); + _mm256_storeu_si256((__m256i *)&out[224], v4_3); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x8 for (; k >= 8; k -= 8) { __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); @@ -263,18 +320,14 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( if (k != 0) { assert(k >= 1 && k <= 7); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - - __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0 = _mm256_blend_epi32(v0, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - v0 = _mm256_and_si256(v0, vmask); - __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - v4 = _mm256_and_si256(v4, vmask); + __m256i v0 = _mm256_set1_epi64x((int64_t) safe_load_u64(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w1, k)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w2, k)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) safe_load_u64(w3, k)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) safe_load_u64(w4, k)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w5, k)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w6, k)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) safe_load_u64(w7, k)), 0xC0); w0 += k; w1 += k; @@ -292,9 +345,10 @@ void xnn_x8_packw_gemm_goi_ukernel_x8c8__avx256skx( out += 64; } + out = (int8_t*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const int8_t*)((intptr_t) weights + nc * kc); } while (--g != 0); } diff --git a/src/x8-packw/kr-avxvnni.c.in b/src/x8-packw/kr-avxvnni.c.in index d3e60e954905..5deb4b581d88 100644 --- a/src/x8-packw/kr-avxvnni.c.in +++ b/src/x8-packw/kr-avxvnni.c.in @@ -5,7 +5,7 @@ $assert NR in [8, 16] $assert KR == 8 -$assert DATATYPE in ["QS8", "X8"] +$assert DATATYPE in ["QS8", "X8", "QS4"] $assert TYPE in ["int8_t"] $assert IZP in [0, 128] @@ -23,22 +23,51 @@ $if VARIANT == "MADD": // AVXVNNI replacement that uses vpmaddubsw. // u7 is vone. s8 is int8 weights. 
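// --- Illustrative aside (not part of the patch) ---
// Scalar reference of the packed layout that the x8c8 GOI packing kernels in
// this patch produce with the unpacklo/unpackhi_epi64 + permute2f128 shuffles
// above. The function name ref_pack_x8_goi_x8c8 is made up for this sketch;
// it models the plain x8 variant (uint32 bias, no row-sum correction) and the
// full-tile case only.
#include <assert.h>
#include <stdint.h>
#include <string.h>

static void ref_pack_x8_goi_x8c8(size_t nc, size_t kc,
                                 const int8_t* weights, const uint32_t* bias,
                                 int8_t* out) {
  assert(nc % 8 == 0);  // full NC tiles only, for brevity
  for (size_t n = 0; n < nc; n += 8) {
    // 8 uint32 bias values first (zeros when bias == NULL)
    for (size_t i = 0; i < 8; i++) {
      const uint32_t b = bias != NULL ? bias[n + i] : 0;
      memcpy(out, &b, sizeof(b));
      out += sizeof(b);
    }
    // then one 8x8 block per group of KR=8 k values:
    // 8 bytes from row n+0, 8 from row n+1, ..., 8 from row n+7,
    // zero-padded past kc (which is what safe_load_u64 achieves above)
    for (size_t k = 0; k < kc; k += 8) {
      for (size_t i = 0; i < 8; i++) {
        const int8_t* row = weights + (n + i) * kc + k;
        const size_t valid = kc - k < 8 ? kc - k : 8;
        memcpy(out, row, valid);
        memset(out + valid, 0, 8 - valid);
        out += 8;
      }
    }
    // the real kernel also advances out by extra_bytes after each tile;
    // omitted here
  }
}
// --- end aside ---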
static XNN_INTRINSIC - __m256i _mm256_dpbusd_epi32_madd(__m256i i32, const __m256i u7, const __m256i s8) { + __m256i mm256_dpbusd_epi32_madd(__m256i i32, const __m256i u7, const __m256i s8) { const __m256i vone = _mm256_set1_epi16(1); const __m256i i16 = _mm256_maddubs_epi16(u7, s8); // u7 * s8 = s16 const __m256i v = _mm256_madd_epi16(i16, vone); // convert 16 bits to 32 bits return _mm256_add_epi32(i32, v); } +XNN_INLINE static uint64_t safe_load_u64(const void* address, size_t n) { + uint64_t value = 0; + assert(n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + return value; +} -$BTYPE = {"QS8": "int32_t", "X8": "uint32_t"}[DATATYPE] -$WTYPE = "int8_t" -$if DATATYPE in ["QS8"]: - $_MM256_DPBUSD_EPI32 = "_mm256_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32" +$BTYPE = {"QS8": "int32_t", "QS4": "int32_t", "X8": "uint32_t"}[DATATYPE] +$WTYPE = {"QS8": "int8_t", "QS4": "uint8_t", "X8": "int8_t"}[DATATYPE] +$PACKEDWTYPE = {"QS8": "int8_t", "QS4": "void", "X8": "int8_t"}[DATATYPE] +$SCALETYPE = {"QS8": "void", "QS4": "float", "X8": "void"}[DATATYPE] +$PARAMTYPE = {"QS8": "void", "QS4": "struct xnn_qs8_qc4w_packing_params", "X8": "void"}[DATATYPE] +$if DATATYPE in ["QS8", "QS4"]: + $_MM256_DPBUSD_EPI32 = "mm256_dpbusd_epi32_madd" if VARIANT == "MADD" else "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32" $ISA = "avx2" if VARIANT == "MADD" else "avxvnni" if AVX == 2 else "avx256vnni" $else: $ISA = "avx2" if AVX == 2 else "avx256skx" -void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VARIANT == "MADD" else ""}${"_prfm" if PREFETCH else ""}( +$DATATYPE_SPEC = "qs8_to_qu8" if IZP == 128 else {"QS8": "qs8", "QS4": "qs8_qc4w", "X8": "x8"}[DATATYPE] +$if DATATYPE in ["QS4"]: + // Convert a vector from packed nibbles to planar, and accumulate sum + static XNN_INTRINSIC + __m256i xnn_packed2planar(__m256i* vacc, const __m256i v, const __m256i vmask, const __m256i vone) { + const __m256i v0213 = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 1, 2, 0)); + const __m256i vt = _mm256_slli_epi32(v0213, 4); // isolate lower int4 + const __m256i vh = _mm256_and_si256(v0213, vmask); // isolate upper int4 + const __m256i vl = _mm256_and_si256(vt, vmask); + const __m256i v01 = _mm256_unpacklo_epi8(vl, vh); + const __m256i v23 = _mm256_unpackhi_epi8(vl, vh); + *vacc = ${_MM256_DPBUSD_EPI32}(*vacc, vone, v01); + *vacc = ${_MM256_DPBUSD_EPI32}(*vacc, vone, v23); + const __m256i vl01 = _mm256_srli_epi32(v01, 4); + return _mm256_or_si256(vl01, v23); + } + +void xnn_${DATATYPE_SPEC}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VARIANT == "MADD" else ""}${"_prfm" if PREFETCH else ""}( size_t g, size_t nc, size_t kc, @@ -47,10 +76,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk size_t sr, const ${WTYPE}* weights, const ${BTYPE}* bias, - const void* scale, - ${WTYPE}* packed_weights, + const ${SCALETYPE}* scale, + ${PACKEDWTYPE}* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const ${PARAMTYPE}* params) { assert(g != 0); assert(nc != 0); @@ -60,21 +89,33 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk assert(sr == 1); assert(weights != NULL); assert(packed_weights != NULL); + assert(params != NULL); + $if DATATYPE == "QS4": + assert(kc % 2 == 0); // This kernel does 
not support odd KC + kc >>= 1; // KR=8 4 bit with 2 planes is 8 bytes. Measure in bytes ${TYPE}* out = (${TYPE}*) packed_weights; const ${BTYPE}* b = (const ${BTYPE}*) bias; $if DATATYPE in ["QS8"]: const __m256i vone = _mm256_set1_epi8(1); - const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP}); - __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP})); + $elif DATATYPE in ["QS4"]: + const __m256i vone = _mm256_set1_epi8(1); + const __m256i vmask = _mm256_set1_epi8(0xF0); + const __m256i vzeropoint = _mm256_set1_epi32((int32_t) params->input_zero_point + ${IZP}); + const __m256i vkernel_zero_point = _mm256_set1_epi32((uint32_t) params->kernel_zero_point * 0x11111111); + assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0); do { // NC main loop multiple of ${NR} const ${TYPE}* w0 = (const ${TYPE}*) weights; size_t n = nc; for (;n >= ${NR}; n -= ${NR}) { - $if DATATYPE in ["QS8"]: + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + + $if DATATYPE in ["QS8", "QS4"]: ${BTYPE}* packed_b = (${BTYPE}*) out; if XNN_LIKELY(b != NULL) { $for N in range(0, NR, 8): @@ -88,14 +129,12 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk } out += ${NR} * sizeof(${BTYPE}); - $for N in range(1, NR): - const ${TYPE}* w${N} = w${N-1} + kc; $if PREFETCH: $for N in range(0, NR): $for OFFSET in range(0, 448, 64): xnn_prefetch_to_l1((const int8_t*) w${N} + ${OFFSET}); - $if DATATYPE in ["QS8"]: + $if DATATYPE in ["QS8", "QS4"]: $for N in range(0, NR, 4): __m256i vacc${N} = _mm256_setzero_si256(); @@ -103,7 +142,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk // KC main loop multiple of ${NR}x${4 * KR} for (; k >= ${4 * KR}; k -= ${4 * KR}) { $for N in range(NR): - const __m256i v${N}_0123 = _mm256_loadu_si256((const __m256i*) w${N}); + $if DATATYPE in ["QS4"]: + const __m256i v${N}_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w${N}), vkernel_zero_point); // uint4 -> int4 + $else: + const __m256i v${N}_0123 = _mm256_loadu_si256((const __m256i*) w${N}); $for N in range(0, NR, 2): const __m256i v${N}${N+1}_02 = _mm256_unpacklo_epi64(v${N}_0123, v${N+1}_0123); @@ -115,12 +157,16 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk $for N in range(0, NR, 4): $for I in range(0, 4): - const __m256i v${N}_${I} = _mm256_permute2f128_si256(v${N}${N+1}_${I%2}${I%2+2}, v${N+2}${N+3}_${I%2}${I%2+2}, _MM_SHUFFLE(0, ${I//2+2}, 0, ${I//2})); + __m256i v${N}_${I} = _mm256_permute2f128_si256(v${N}${N+1}_${I%2}${I%2+2}, v${N+2}${N+3}_${I%2}${I%2+2}, _MM_SHUFFLE(0, ${I//2+2}, 0, ${I//2})); $if DATATYPE in ["QS8"]: $for N in range(0, NR, 4): $for I in range(0, 4): vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}_${I}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 4): + $for I in range(0, 4): + v${N}_${I} = xnn_packed2planar(&vacc${N}, v${N}_${I}, vmask, vone); $for I in range(0, 4): $for N in range(0, NR, 4): @@ -145,6 +191,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk $if DATATYPE in ["QS8"]: $for N in range(0, NR, 4): vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 4): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + 
v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); $for N in range(0, NR, 4): _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); @@ -158,14 +208,11 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk if (k != 0) { assert(k >= 1 && k <= ${KR-1}); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - $for N in range(0, NR, 4): - __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N})); - v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+1})), 0x0C); - v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+2})), 0x30); - v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+3})), 0xC0); - v${N} = _mm256_and_si256(v${N}, vmask); + __m256i v${N} = _mm256_set1_epi64x((int64_t) safe_load_u64(w${N}, k)); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+1}, k)), 0x0C); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+2}, k)), 0x30); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+3}, k)), 0xC0); $for N in range(NR): w${N} += k; @@ -173,6 +220,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk $if DATATYPE in ["QS8"]: $for N in range(0, NR, 4): vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 4): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); $for N in range(0, NR, 4): _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); @@ -180,7 +231,7 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk out += ${NR*KR}; } - $if DATATYPE in ["QS8"]: + $if DATATYPE in ["QS8", "QS4"]: $for N in range(0, NR, 8): __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4}); vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0)); @@ -197,26 +248,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk } // NC remainder (1..${NR-1}) + // Same as main loop except bias is copied and w pointers are clamped if XNN_UNLIKELY(n != 0) { assert(n >= 1 && n <= ${NR-1}); - - $if DATATYPE in ["QS8"]: - ${BTYPE}* packed_b = (${BTYPE}*) out; - if XNN_LIKELY(b != NULL) { - size_t nb = n; - do { - *((${BTYPE}*) out) = *b++; - out += sizeof(${BTYPE}); - } while (--nb != 0); - } else { - size_t nb = n; - do { - *((${BTYPE}*) out) = 0; - out += sizeof(${BTYPE}); - } while (--nb != 0); - } - out += (${NR} - n) * sizeof(${BTYPE}); - + // Clamp weight pointers for NC remainder $for N in range(1, NR): const ${TYPE}* w${N} = w${N-1} + kc; $if N % 2 == 0: @@ -227,17 +262,70 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk if XNN_UNPREDICTABLE(n < ${N+1}) { w${N} = w${N-1}; } + + $if DATATYPE in ["QS8", "QS4"]: + ${BTYPE}* packed_b = (${BTYPE}*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + for (nb = 0; nb < n; ++nb) { + ((${BTYPE}*) out)[nb] = b[nb]; + } + b += n; + } else { + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256()); + } + out += ${NR} * sizeof(${BTYPE}); + $if PREFETCH: $for N in range(0, NR): - xnn_prefetch_to_l1((const int8_t*) w${N}); - xnn_prefetch_to_l1((const int8_t*) w${N} + 64); + $for OFFSET in range(0, 448, 64): + xnn_prefetch_to_l1((const int8_t*) w${N} + ${OFFSET}); 
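// --- Illustrative aside (not part of the patch) ---
// Scalar model of the QS4 path above: each packed QC4W byte holds weight k in
// the low nibble and weight k+1 in the high nibble, stored as uint4 with a
// kernel zero point of 8 (or 0). For zero point 8, XOR-ing every nibble with 8
// (the vectorized _mm256_xor_si256(v, vkernel_zero_point)) yields the
// two's-complement int4 bit pattern of (value - 8). xnn_packed2planar then
// separates the two nibbles into distinct lanes and accumulates their sums
// into vacc (still scaled by 16, i.e. left in the upper-nibble position) for
// the zero-point correction of the packed bias. The helper name below is made
// up; it only shows the per-byte arithmetic, not the SIMD layout or scaling.
#include <stdint.h>

static void unpack_qc4w_byte(uint8_t packed, uint8_t kernel_zero_point,
                             int8_t* lo, int8_t* hi, int32_t* ksum) {
  const uint8_t u_lo = packed & 0x0F;  // weight k
  const uint8_t u_hi = packed >> 4;    // weight k+1
  // uint4 with zero point -> signed weight (in [-8, 7] for zero point 8)
  *lo = (int8_t) u_lo - (int8_t) kernel_zero_point;
  *hi = (int8_t) u_hi - (int8_t) kernel_zero_point;
  // running sum of signed weights, later multiplied by the input zero point
  // and subtracted from the packed bias (the vksum/vpack logic above)
  *ksum += (int32_t) *lo + (int32_t) *hi;
}
// --- end aside ---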
- $if DATATYPE in ["QS8"]: + $if DATATYPE in ["QS8", "QS4"]: $for N in range(0, NR, 4): __m256i vacc${N} = _mm256_setzero_si256(); - // KC main loop multiple of ${NR}x${KR} size_t k = kc; + // KC main loop multiple of ${NR}x${4 * KR} + for (; k >= ${4 * KR}; k -= ${4 * KR}) { + $for N in range(NR): + $if DATATYPE in ["QS4"]: + const __m256i v${N}_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w${N}), vkernel_zero_point); // uint4 -> int4 + $else: + const __m256i v${N}_0123 = _mm256_loadu_si256((const __m256i*) w${N}); + + $for N in range(0, NR, 2): + const __m256i v${N}${N+1}_02 = _mm256_unpacklo_epi64(v${N}_0123, v${N+1}_0123); + const __m256i v${N}${N+1}_13 = _mm256_unpackhi_epi64(v${N}_0123, v${N+1}_0123); + + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $for N in range(0, NR, 4): + $for I in range(0, 4): + __m256i v${N}_${I} = _mm256_permute2f128_si256(v${N}${N+1}_${I%2}${I%2+2}, v${N+2}${N+3}_${I%2}${I%2+2}, _MM_SHUFFLE(0, ${I//2+2}, 0, ${I//2})); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + $for I in range(0, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}_${I}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 4): + $for I in range(0, 4): + v${N}_${I} = xnn_packed2planar(&vacc${N}, v${N}_${I}, vmask, vone); + + $for I in range(0, 4): + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${(I*NR + N)*KR}], v${N}_${I}); + + $for N in range(NR): + w${N} += ${4 * KR}; + out += ${4*NR*KR}; + } + + // KC main loop multiple of ${NR}x${KR} for (; k >= ${KR}; k -= ${KR}) { $for N in range(0, NR, 4): __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N})); @@ -251,6 +339,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk $if DATATYPE in ["QS8"]: $for N in range(0, NR, 4): vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in range(0, NR, 4): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); $for N in range(0, NR, 4): _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); @@ -264,14 +356,11 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk if (k != 0) { assert(k >= 1 && k <= ${KR-1}); - const __m256i vmask = _mm256_srli_epi64(_mm256_set1_epi32(-1), (8 - k) * sizeof(int8_t) * 8); - $for N in range(0, NR, 4): - __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N})); - v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+1})), 0x0C); - v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+2})), 0x30); - v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+3})), 0xC0); - v${N} = _mm256_and_si256(v${N}, vmask); + __m256i v${N} = _mm256_set1_epi64x((int64_t) safe_load_u64(w${N}, k)); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+1}, k)), 0x0C); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+2}, k)), 0x30); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) safe_load_u64(w${N+3}, k)), 0xC0); $for N in range(NR): w${N} += k; @@ -279,6 +368,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk $if DATATYPE in ["QS8"]: $for N in range(0, NR, 4): vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + $elif DATATYPE in ["QS4"]: + $for N in 
range(0, NR, 4): + v${N} = _mm256_xor_si256(v${N}, vkernel_zero_point); // uint4 -> int4 + v${N} = xnn_packed2planar(&vacc${N}, v${N}, vmask, vone); $for N in range(0, NR, 4): _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); @@ -286,7 +379,7 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk out += ${NR*KR}; } - $if DATATYPE in ["QS8"]: + $if DATATYPE in ["QS8", "QS4"]: $for N in range(0, NR, 8): __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4}); vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0)); @@ -298,9 +391,10 @@ void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_uk vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); $for N in range(0, NR, 8): _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); } - weights += nc * kc; + weights = (const ${WTYPE}*)((intptr_t) weights + nc * kc); } while (--g != 0); -} +} \ No newline at end of file diff --git a/src/x8-packw/kr-gio-avxvnni.c.in b/src/x8-packw/kr-gio-avxvnni.c.in new file mode 100644 index 000000000000..34be5be1f6b4 --- /dev/null +++ b/src/x8-packw/kr-gio-avxvnni.c.in @@ -0,0 +1,317 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR in [8, 16] +$assert KR == 8 +$assert DATATYPE in ["QS8", "X8"] +$assert TYPE in ["int8_t"] +$assert IZP in [0, 128] +$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +$if PREFETCH: + #include "xnnpack/prefetch.h" + +XNN_INLINE static uint64_t safe_load_u64(const void* src, size_t n) { + uint64_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < n; ++i) { + value |= (uint64_t)bytes[i] << (i * 8); + } + return value; +} + +$BTYPE = {"QS8": "int32_t", "X8": "uint32_t"}[DATATYPE] +$WTYPE = "int8_t" +$if DATATYPE in ["QS8"]: + $_MM256_DPBUSD_EPI32 = "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32" + $ISA = "avxvnni" if AVX == 2 else "avx256vnni" +$else: + $ISA = "avx2" if AVX == 2 else "avx256skx" +void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_gio_ukernel_x${NR}c${KR}__${ISA}${"_prfm" if PREFETCH else ""}( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t k_stride, + const ${WTYPE}* weights, + const ${BTYPE}* bias, + const void* scale, + ${WTYPE}* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + + $if DATATYPE in ["QS8"]: + const __m256i vone = _mm256_set1_epi8(1); + const uint32_t izp = (uint32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP}); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (;n >= ${NR}; n -= ${NR}) { + $if DATATYPE in ["QS8"]: + ${BTYPE}* packed_b = (${BTYPE}*) out; + if XNN_LIKELY(b != NULL) { + $for N in range(0, NR, 8): + const __m256i vb${N} = _mm256_loadu_si256((const __m256i*) (b + ${N})); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), vb${N}); + b += ${NR}; + } else { + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256()); + } + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + k_stride; + $if PREFETCH: + $for N in range(0, NR): + $for OFFSET in range(0, 448, 64): + xnn_prefetch_to_l1((const int8_t*) w${N} + ${OFFSET}); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + __m256i vacc${N} = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of ${NR}x${KR} + for (; k >= ${KR}; k -= ${KR}) { + $for K in range(KR): + __m128i v${K}x${ABC[0:NR]} = _mm_loadu_si64(w${K}); + $if PREFETCH: + $for K in range(0, KR): + xnn_prefetch_to_l1((const int8_t*) w${K} + 448); + + $for K in range(0, KR, 2): + __m128i v${ABC[K:K+2]}x${ABC[0:NR]} = _mm_unpacklo_epi8(v${K}x${ABC[0:NR]}, v${K+1}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[K:K+4]}x${ABC[0:4]} = _mm_unpacklo_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + __m128i v${ABC[K:K+4]}x${ABC[4:8]} = _mm_unpackhi_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[0:NR]}x${ABC[K:K+2]} = _mm_unpacklo_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + __m128i v${ABC[0:NR]}x${ABC[K+2:K+4]} = _mm_unpackhi_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_inserti128_si256(_mm256_castsi128_si256(v${ABC[0:NR]}x${ABC[N:N+2]}), v${ABC[0:NR]}x${ABC[N+2:N+4]}, 1); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += ${KR} * k_stride; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + assert(k >= 1 && k <= ${KR-1}); + + __m128i vzero = _mm_setzero_si128(); + __m128i v0x${ABC[0:NR]} = _mm_loadu_si64(w0); + $for K in range(1, KR): + __m128i v${K}x${ABC[0:NR]} = vzero; + if (${K} < k) { + v${K}x${ABC[0:NR]} = _mm_loadu_si64(w${K}); + } + + $for K in range(0, KR, 2): + __m128i v${ABC[K:K+2]}x${ABC[0:NR]} = _mm_unpacklo_epi8(v${K}x${ABC[0:NR]}, v${K+1}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[K:K+4]}x${ABC[0:4]} = _mm_unpacklo_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + __m128i v${ABC[K:K+4]}x${ABC[4:8]} = _mm_unpackhi_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[0:NR]}x${ABC[K:K+2]} = _mm_unpacklo_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + __m128i v${ABC[0:NR]}x${ABC[K+2:K+4]} = _mm_unpackhi_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + + $for N in range(0, NR, 4): + __m256i v${N} = 
_mm256_inserti128_si256(_mm256_castsi128_si256(v${ABC[0:NR]}x${ABC[N:N+2]}), v${ABC[0:NR]}x${ABC[N+2:N+4]}, 1); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += k * k_stride; + out += ${NR*KR}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4}); + vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0)); + $for N in range(0, NR, 8): + vksum${N} = _mm256_mullo_epi32(vksum${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w0 - kc * k_stride + ${NR}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + $if DATATYPE in ["QS8"]: + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = *b++; + $else: + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } else { + size_t nb = n; + do { + $if BTYPE == TYPE: + *out++ = 0; + $else: + *((${BTYPE}*) out) = 0; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } + $if BTYPE == TYPE: + out += (${NR} - n); + $else: + out += (${NR} - n) * sizeof(${BTYPE}); + + $if NR > 2: + $for K in range(1, KR): + const ${TYPE}* w${K} = w${K-1} + k_stride; + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + __m256i vacc${N} = _mm256_setzero_si256(); + + size_t k = kc; + + // KC main loop multiple of ${NR}x${KR} + for (; k >= ${KR}; k -= ${KR}) { + $for K in range(KR): + __m128i v${K}x${ABC[0:NR]} = _mm_set1_epi64x((int64_t) safe_load_u64(w${K}, n)); + + $for K in range(0, KR, 2): + __m128i v${ABC[K:K+2]}x${ABC[0:NR]} = _mm_unpacklo_epi8(v${K}x${ABC[0:NR]}, v${K+1}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[K:K+4]}x${ABC[0:4]} = _mm_unpacklo_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + __m128i v${ABC[K:K+4]}x${ABC[4:8]} = _mm_unpackhi_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[0:NR]}x${ABC[K:K+2]} = _mm_unpacklo_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + __m128i v${ABC[0:NR]}x${ABC[K+2:K+4]} = _mm_unpackhi_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_inserti128_si256(_mm256_castsi128_si256(v${ABC[0:NR]}x${ABC[N:N+2]}), v${ABC[0:NR]}x${ABC[N+2:N+4]}, 1); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += ${KR} * k_stride; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + assert(k >= 1 && k <= ${KR-1}); + + __m128i vzero = _mm_setzero_si128(); + __m128i v0x${ABC[0:NR]} = _mm_set1_epi64x((int64_t) safe_load_u64(w0, n)); + $for K in range(1, KR): + __m128i v${K}x${ABC[0:NR]} = vzero; + if (${K} < k) { + v${K}x${ABC[0:NR]} = _mm_set1_epi64x((int64_t) safe_load_u64(w${K}, n)); + } + + $for K in range(0, KR, 2): + __m128i 
v${ABC[K:K+2]}x${ABC[0:NR]} = _mm_unpacklo_epi8(v${K}x${ABC[0:NR]}, v${K+1}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[K:K+4]}x${ABC[0:4]} = _mm_unpacklo_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + __m128i v${ABC[K:K+4]}x${ABC[4:8]} = _mm_unpackhi_epi16(v${ABC[K:K+2]}x${ABC[0:NR]}, v${ABC[K+2:K+4]}x${ABC[0:NR]}); + + $for K in range(0, KR, 4): + __m128i v${ABC[0:NR]}x${ABC[K:K+2]} = _mm_unpacklo_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + __m128i v${ABC[0:NR]}x${ABC[K+2:K+4]} = _mm_unpackhi_epi32(v0123x${ABC[(K//4)*4:(K//4)*4 + 4]}, v4567x${ABC[(K//4)*4:(K//4)*4 + 4]}); + + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_inserti128_si256(_mm256_castsi128_si256(v${ABC[0:NR]}x${ABC[N:N+2]}), v${ABC[0:NR]}x${ABC[N+2:N+4]}, 1); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += k * k_stride; + out += ${NR*KR}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4}); + vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0)); + $for N in range(0, NR, 8): + vksum${N} = _mm256_mullo_epi32(vksum${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w0 - kc * k_stride + ${NR}; + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x8-packw/kr-wasmdot.c.in b/src/x8-packw/kr-wasmdot.c.in index 4b5bd727730a..d0a4100d1945 100644 --- a/src/x8-packw/kr-wasmdot.c.in +++ b/src/x8-packw/kr-wasmdot.c.in @@ -14,6 +14,16 @@ $assert IZP in [0, 128] #include "xnnpack/packw.h" +XNN_INLINE static v128_t safe_v128_load64_splat(const void* address, size_t n) { + assert(n >= 1 && n <= sizeof(uint64_t)); + const uint8_t* bytes = (const uint8_t*) address; + uint64_t value = (uint64_t) bytes[0]; + for (size_t i = 1; i < n; ++i) { + value |= (uint64_t) bytes[i] << (i * 8); + } + + return wasm_u64x2_splat(value); +} $ABC = "012345678" $BTYPE = {"int8_t": "uint32_t"}[TYPE] @@ -30,7 +40,7 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K const void* scale, ${WTYPE}* packed_weights, size_t extra_bytes, - const void* params) XNN_OOB_READS + const void* params) { assert(g != 0); assert(nc != 0); @@ -115,17 +125,18 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K out += ${NR*KR}; } + // Load ealier to avoid unexpected rescheduling. 
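// --- Illustrative aside (not part of the patch) ---
// The masked loads removed in these hunks always read a full 8 bytes past the
// last valid weight and relied on the XNN_OOB_READS annotation; the safe_*
// helpers added here assemble the same little-endian value from only the k
// valid bytes, so the annotation can be dropped. A minimal sketch of the
// equivalence (plain host-side C, helper names made up for the sketch):
#include <assert.h>
#include <stdint.h>
#include <string.h>

// Old approach: load 8 bytes unconditionally, then clear the bytes beyond the
// k valid ones. May read up to 7 bytes past the end of the row.
static uint64_t masked_load_u64_sketch(const void* address, size_t k) {
  assert(k >= 1 && k <= 8);
  uint64_t value;
  memcpy(&value, address, sizeof(value));          // possible out-of-bounds read
  const uint64_t mask = ~UINT64_C(0) >> ((8 - k) * 8);  // keep the low k bytes
  return value & mask;
}

// New approach, as in safe_load_u64 / safe_v128_load64_splat above: read only
// the k valid bytes and build the value byte by byte.
static uint64_t safe_load_u64_sketch(const void* address, size_t k) {
  assert(k >= 1 && k <= 8);
  const uint8_t* bytes = (const uint8_t*) address;
  uint64_t value = 0;
  for (size_t i = 0; i < k; ++i) {
    value |= (uint64_t) bytes[i] << (i * 8);
  }
  return value;
}
// On little-endian targets both return the same value whenever the trailing
// bytes happen to be readable; only the second is valid when they are not.
// --- end aside ---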
+ v128_t vpack0123 = wasm_v128_load(packed_b); + v128_t vpack4567 = wasm_v128_load(packed_b + 4); + // KC remainder 1..KR-1 if (k != 0) { assert(k >= 1 && k <= ${KR-1}); - const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8); - $for N in range(0, NR, 2): - const v128_t v${N} = wasm_v128_load64_splat(w${N}); - const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1}); - v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3); - v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask); + const v128_t v${N} = safe_v128_load64_splat(w${N}, k); + const v128_t v${N+1} = safe_v128_load64_splat(w${N+1}, k); + const v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3); $for N in range(0, NR, 2): vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]}); @@ -144,9 +155,6 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint); vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint); - v128_t vpack0123 = wasm_v128_load(packed_b); - v128_t vpack4567 = wasm_v128_load(packed_b + 4); - wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123)); wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567)); @@ -207,17 +215,18 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K out += ${NR*KR}; } + // Load ealier to avoid unexpected rescheduling. + v128_t vpack0123 = wasm_v128_load(packed_b); + v128_t vpack4567 = wasm_v128_load(packed_b + 4); + // KC remainder of 1..${KR-1} if (k != 0) { assert(k >= 1 && k <= ${KR-1}); - const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8); - $for N in range(0, NR, 2): - const v128_t v${N} = wasm_v128_load64_splat(w${N}); - const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1}); - v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3); - v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask); + const v128_t v${N} = safe_v128_load64_splat(w${N}, k); + const v128_t v${N+1} = safe_v128_load64_splat(w${N+1}, k); + const v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3); $for N in range(0, NR, 2): vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]}); @@ -234,9 +243,6 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint); vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint); - v128_t vpack0123 = wasm_v128_load(packed_b); - v128_t vpack4567 = wasm_v128_load(packed_b + 4); - wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123)); wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567)); diff --git a/src/x8-zip/x8-zip-x2-neon.c b/src/x8-zip/x8-zip-x2-neon.c deleted file mode 100644 index 19b9c97f6637..000000000000 --- a/src/x8-zip/x8-zip-x2-neon.c +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x2_ukernel__neon( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - uint8_t* o = output; - - if (n >= 8) { - do { - uint8x8x2_t vxy; - vxy.val[0] = vld1_u8(x); x += 8; - vxy.val[1] = vld1_u8(y); y += 8; - vst2_u8(o, vxy); o += 16;; - n -= 8; - } while (n >= 8); - if (n != 0) { - const size_t address_increment = n - 8; - uint8x8x2_t vxy; - vxy.val[0] = vld1_u8((const uint8_t*) ((uintptr_t) x + address_increment)); - vxy.val[1] = vld1_u8((const uint8_t*) ((uintptr_t) y + address_increment)); - vst2_u8((uint8_t*) ((uintptr_t) o + address_increment * 2), vxy); - } - } else { - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - o[0] = vx; - o[1] = vy; - o += 2; - } while (--n != 0); - } -} diff --git a/src/x8-zip/x8-zip-x2-scalar.c b/src/x8-zip/x8-zip-x2-scalar.c deleted file mode 100644 index a0ffcd24ce4a..000000000000 --- a/src/x8-zip/x8-zip-x2-scalar.c +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x2_ukernel__scalar( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - assert(n != 0); - - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - uint8_t* o = output; - - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - o[0] = vx; - o[1] = vy; - o += 2; - - n -= sizeof(uint8_t); - } while (n != 0); -} diff --git a/src/x8-zip/x8-zip-x2-sse2.c b/src/x8-zip/x8-zip-x2-sse2.c deleted file mode 100644 index 640832fa8744..000000000000 --- a/src/x8-zip/x8-zip-x2-sse2.c +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x2_ukernel__sse2( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - uint8_t* o = output; - - if (n >= 16) { - do { - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 16; - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 16; - const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); - _mm_storeu_si128((__m128i*) o, vxy_lo); - _mm_storeu_si128((__m128i*) (o + 16), vxy_hi); - o = (void*) ((uintptr_t) o + 32); - n -= 16; - } while (n >= 16); - if (n != 0) { - const size_t address_increment = n - 16; - const __m128i vx = _mm_loadu_si128((const __m128i*) ((uintptr_t) x + address_increment)); - const __m128i vy = _mm_loadu_si128((const __m128i*) ((uintptr_t) y + address_increment)); - const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); - o = (void*) ((uintptr_t) o + address_increment * 2); - _mm_storeu_si128((__m128i*) o, vxy_lo); - _mm_storeu_si128((__m128i*) o + 1, vxy_hi); - } - } else { - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - o[0] = vx; - o[1] = vy; - o += 2; - } while (--n != 0); - } -} diff --git a/src/x8-zip/x8-zip-x3-neon.c b/src/x8-zip/x8-zip-x3-neon.c deleted file mode 100644 index 6e947ebdee2d..000000000000 --- a/src/x8-zip/x8-zip-x3-neon.c +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x3_ukernel__neon( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n); - uint8_t* o = output; - - if (n >= 8) { - do { - uint8x8x3_t vxyz; - vxyz.val[0] = vld1_u8(x); x += 8; - vxyz.val[1] = vld1_u8(y); y += 8; - vxyz.val[2] = vld1_u8(z); z += 8; - vst3_u8(o, vxyz); o += 24; - n -= 8; - } while (n >= 8); - if (n != 0) { - const size_t address_increment = n - 8; - uint8x8x3_t vxyz; - vxyz.val[0] = vld1_u8((const uint8_t*) ((uintptr_t) x + address_increment)); - vxyz.val[1] = vld1_u8((const uint8_t*) ((uintptr_t) y + address_increment)); - vxyz.val[2] = vld1_u8((const uint8_t*) ((uintptr_t) z + address_increment)); - vst3_u8((uint8_t*) ((uintptr_t) o + address_increment * 3), vxyz); - } - } else { - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - const uint8_t vz = *z++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o += 3; - } while (--n != 0); - } -} diff --git a/src/x8-zip/x8-zip-x3-scalar.c b/src/x8-zip/x8-zip-x3-scalar.c deleted file mode 100644 index a5768086de61..000000000000 --- a/src/x8-zip/x8-zip-x3-scalar.c +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x3_ukernel__scalar( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n); - uint8_t* o = output; - - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - const uint8_t vz = *z++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o += 3; - - n -= sizeof(uint8_t); - } while (n != 0); -} diff --git a/src/x8-zip/x8-zip-x3-sse2.c b/src/x8-zip/x8-zip-x3-sse2.c deleted file mode 100644 index 4ac5cfd76ee9..000000000000 --- a/src/x8-zip/x8-zip-x3-sse2.c +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x3_ukernel__sse2( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n); - uint8_t* o = output; - - if (n >= 16) { - const __m128i vmask0x00FF00FF = _mm_set1_epi16(0x00FF); - const __m128i vmask0x0000FFFF = _mm_set1_epi32(0x0000FFFF); - do { - // vx = ( x15, x14, x13, x12, x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1, x0 ) - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 16; - // vy = ( y15, y14, y13, y12, y11, y10, y9, y8, y7, y6, y5, y4, y3, y2, y1, y0 ) - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 16; - // vz = ( z15, z14, z13, z12, z11, z10, z9, z8, z7, z6, z5, z4, z3, z2, z1, z0 ) - const __m128i vz = _mm_loadu_si128((const __m128i*) z); - z += 16; - - // vxeye = ( y14, x14, y12, x12, y10, x10, y8, x8, y6, x6, y4, x4, y2, x2, y0, x0 ) - const __m128i vxeye = _mm_or_si128(_mm_and_si128(vx, vmask0x00FF00FF), _mm_slli_epi16(vy, 8)); - // vyozo = ( z15, y15, z13, y13, z11, y11, z9, y9, z7, y7, z5, y5, z3, y3, z1, y1 ) - const __m128i vyozo = _mm_or_si128(_mm_andnot_si128(vmask0x00FF00FF, vz), _mm_srli_epi16(vy, 8)); - // vzoxo = ( x15, z14, x13, z12, x11, z10, x9, z8, x7, z6, x5, z4, x3, z2, x1, z0 ) - const __m128i vzexo = _mm_or_si128(_mm_and_si128(vz, vmask0x00FF00FF), _mm_andnot_si128(vmask0x00FF00FF, vx)); - - // vxeyezexo = ( x13, z12, y12, x12, x9, z8, y8, x8, x5, z4, y4, x4, x1, z0, y0, x0 ) - const __m128i vxeyezexo = _mm_or_si128(_mm_and_si128(vxeye, vmask0x0000FFFF), _mm_slli_epi32(vzexo, 16)); - // vyozoxeye = ( y14, x14, z13, y13, y10, x10, z9, y9, y6, x6, z5, y5, y2, x2, z1, y1 ) - const __m128i vyozoxeye = _mm_or_si128(_mm_and_si128(vyozo, vmask0x0000FFFF), _mm_andnot_si128(vmask0x0000FFFF, vxeye)); - // vzexoyozo = ( z15, y15, x15, z14, z11, y11, x11, z10, z7, y7, x7, z6, z3, y3, x3, z2 ) - const __m128i vzexoyozo = _mm_or_si128(_mm_andnot_si128(vmask0x0000FFFF, vyozo), _mm_srli_epi32(vzexo, 16)); - - // vtemp0 = ( x13, z12, y12, x12, x5, z4, y4, x4, z11, y11, x11, z10, z3, y3, x3, z2 ) - const __m128i vtemp0 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vzexoyozo), _mm_castsi128_ps(vxeyezexo), _MM_SHUFFLE(3, 1, 2, 0))); - // vtemp1 = ( y10, x10, z9, y9, y2, x2, z1, y1, x9, z8, y8, x8, x1, z0, y0, x0 ) - const __m128i vtemp1 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vxeyezexo), _mm_castsi128_ps(vyozoxeye), _MM_SHUFFLE(2, 0, 2, 0))); - // vtemp2 = ( z15, y15, x15, z14, z7, y7, x7, z6, y14, 
x14, z13, y13, y6, x6, z5, y5 ) - const __m128i vtemp2 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vyozoxeye), _mm_castsi128_ps(vzexoyozo), _MM_SHUFFLE(3, 1, 3, 1))); - - // vxyz0 = ( x5, z4, y4, x4, z3, y3, x3, z2, y2, x2, z1, y1, x1, z0, y0, x0 ) - const __m128i vxyz0 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vtemp1), _mm_castsi128_ps(vtemp0), _MM_SHUFFLE(2, 0, 2, 0))); - // vxyz1 = ( y10, x10, z9, y9, x9, z8, y8, x8, z7, y7, x7, z6, y6, x6, z5, y5 ) - const __m128i vxyz1 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vtemp2), _mm_castsi128_ps(vtemp1), _MM_SHUFFLE(3, 1, 2, 0))); - // vxyz2 = ( z15, y15, x15, z14, y14, x14, z13, y13, x13, z12, y12, x12, z11, y11, x11, z10 ) - const __m128i vxyz2 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vtemp0), _mm_castsi128_ps(vtemp2), _MM_SHUFFLE(3, 1, 3, 1))); - - _mm_storeu_si128((__m128i*) o, vxyz0); - _mm_storeu_si128((__m128i*) o + 1, vxyz1); - _mm_storeu_si128((__m128i*) o + 2, vxyz2); - o += 48; - n -= 16; - } while (n >= 16); - if (n != 0) { - const size_t address_increment = n - 16; - // vx = ( x15, x14, x13, x12, x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1, x0 ) - const __m128i vx = _mm_loadu_si128((const __m128i*) ((uintptr_t) x + address_increment)); - // vy = ( y15, y14, y13, y12, y11, y10, y9, y8, y7, y6, y5, y4, y3, y2, y1, y0 ) - const __m128i vy = _mm_loadu_si128((const __m128i*) ((uintptr_t) y + address_increment)); - // vz = ( z15, z14, z13, z12, z11, z10, z9, z8, z7, z6, z5, z4, z3, z2, z1, z0 ) - const __m128i vz = _mm_loadu_si128((const __m128i*) ((uintptr_t) z + address_increment)); - - // vxeye = ( y14, x14, y12, x12, y10, x10, y8, x8, y6, x6, y4, x4, y2, x2, y0, x0 ) - const __m128i vxeye = _mm_or_si128(_mm_and_si128(vx, vmask0x00FF00FF), _mm_slli_epi16(vy, 8)); - // vyozo = ( z15, y15, z13, y13, z11, y11, z9, y9, z7, y7, z5, y5, z3, y3, z1, y1 ) - const __m128i vyozo = _mm_or_si128(_mm_andnot_si128(vmask0x00FF00FF, vz), _mm_srli_epi16(vy, 8)); - // vzoxo = ( x15, z14, x13, z12, x11, z10, x9, z8, x7, z6, x5, z4, x3, z2, x1, z0 ) - const __m128i vzexo = _mm_or_si128(_mm_and_si128(vz, vmask0x00FF00FF), _mm_andnot_si128(vmask0x00FF00FF, vx)); - - // vxeyezexo = ( x13, z12, y12, x12, x9, z8, y8, x8, x5, z4, y4, x4, x1, z0, y0, x0 ) - const __m128i vxeyezexo = _mm_or_si128(_mm_and_si128(vxeye, vmask0x0000FFFF), _mm_slli_epi32(vzexo, 16)); - // vyozoxeye = ( y14, x14, z13, y13, y10, x10, z9, y9, y6, x6, z5, y5, y2, x2, z1, y1 ) - const __m128i vyozoxeye = _mm_or_si128(_mm_and_si128(vyozo, vmask0x0000FFFF), _mm_andnot_si128(vmask0x0000FFFF, vxeye)); - // vzexoyozo = ( z15, y15, x15, z14, z11, y11, x11, z10, z7, y7, x7, z6, z3, y3, x3, z2 ) - const __m128i vzexoyozo = _mm_or_si128(_mm_andnot_si128(vmask0x0000FFFF, vyozo), _mm_srli_epi32(vzexo, 16)); - - // vtemp0 = ( x13, z12, y12, x12, x5, z4, y4, x4, z11, y11, x11, z10, z3, y3, x3, z2 ) - const __m128i vtemp0 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vzexoyozo), _mm_castsi128_ps(vxeyezexo), _MM_SHUFFLE(3, 1, 2, 0))); - // vtemp1 = ( y10, x10, z9, y9, y2, x2, z1, y1, x9, z8, y8, x8, x1, z0, y0, x0 ) - const __m128i vtemp1 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vxeyezexo), _mm_castsi128_ps(vyozoxeye), _MM_SHUFFLE(2, 0, 2, 0))); - // vtemp2 = ( z15, y15, x15, z14, z7, y7, x7, z6, y14, x14, z13, y13, y6, x6, z5, y5 ) - const __m128i vtemp2 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vyozoxeye), _mm_castsi128_ps(vzexoyozo), _MM_SHUFFLE(3, 1, 3, 1))); - - // vxyz0 = ( x5, z4, y4, x4, z3, y3, x3, z2, y2, x2, z1, 
y1, x1, z0, y0, x0 ) - const __m128i vxyz0 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vtemp1), _mm_castsi128_ps(vtemp0), _MM_SHUFFLE(2, 0, 2, 0))); - // vxyz1 = ( y10, x10, z9, y9, x9, z8, y8, x8, z7, y7, x7, z6, y6, x6, z5, y5 ) - const __m128i vxyz1 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vtemp2), _mm_castsi128_ps(vtemp1), _MM_SHUFFLE(3, 1, 2, 0))); - // vxyz2 = ( z15, y15, x15, z14, y14, x14, z13, y13, x13, z12, y12, x12, z11, y11, x11, z10 ) - const __m128i vxyz2 = _mm_castps_si128( - _mm_shuffle_ps(_mm_castsi128_ps(vtemp0), _mm_castsi128_ps(vtemp2), _MM_SHUFFLE(3, 1, 3, 1))); - - o = (uint8_t*) ((uintptr_t) o + address_increment * 3); - _mm_storeu_si128((__m128i*) o, vxyz0); - _mm_storeu_si128((__m128i*) o + 1, vxyz1); - _mm_storeu_si128((__m128i*) o + 2, vxyz2); - } - } else { - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - const uint8_t vz = *z++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o += 3; - } while (--n != 0); - } -} diff --git a/src/x8-zip/x8-zip-x4-neon.c b/src/x8-zip/x8-zip-x4-neon.c deleted file mode 100644 index 158b325f6535..000000000000 --- a/src/x8-zip/x8-zip-x4-neon.c +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x4_ukernel__neon( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n); - const uint8_t* w = (const uint8_t*) ((uintptr_t) z + n); - uint8_t* o = output; - - if (n >= 8) { - do { - uint8x8x4_t vxyzw; - vxyzw.val[0] = vld1_u8(x); x += 8; - vxyzw.val[1] = vld1_u8(y); y += 8; - vxyzw.val[2] = vld1_u8(z); z += 8; - vxyzw.val[3] = vld1_u8(w); w += 8; - vst4_u8(o, vxyzw); o += 32; - n -= 8; - } while (n >= 8); - if (n != 0) { - const size_t address_increment = n - 8; - uint8x8x4_t vxyzw; - vxyzw.val[0] = vld1_u8((const uint8_t*) ((uintptr_t) x + address_increment)); - vxyzw.val[1] = vld1_u8((const uint8_t*) ((uintptr_t) y + address_increment)); - vxyzw.val[2] = vld1_u8((const uint8_t*) ((uintptr_t) z + address_increment)); - vxyzw.val[3] = vld1_u8((const uint8_t*) ((uintptr_t) w + address_increment)); - vst4_u8((uint8_t*) ((uintptr_t) o + address_increment * 4), vxyzw); - } - } else { - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - const uint8_t vz = *z++; - const uint8_t vw = *w++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - o += 4; - } while (--n != 0); - } -} diff --git a/src/x8-zip/x8-zip-x4-scalar.c b/src/x8-zip/x8-zip-x4-scalar.c deleted file mode 100644 index bfce3071ab33..000000000000 --- a/src/x8-zip/x8-zip-x4-scalar.c +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x4_ukernel__scalar( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - assert(n != 0); - - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n); - const uint8_t* w = (const uint8_t*) ((uintptr_t) z + n); - uint8_t* o = output; - - do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - const uint8_t vz = *z++; - const uint8_t vw = *w++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - o += 4; - - n -= sizeof(uint8_t); - } while (n != 0); -} diff --git a/src/x8-zip/x8-zip-x4-sse2.c b/src/x8-zip/x8-zip-x4-sse2.c deleted file mode 100644 index b00dfc8901e0..000000000000 --- a/src/x8-zip/x8-zip-x4-sse2.c +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_x4_ukernel__sse2( - size_t n, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* x = input; - const uint8_t* y = (const uint8_t*) ((uintptr_t) x + n); - const uint8_t* z = (const uint8_t*) ((uintptr_t) y + n); - const uint8_t* w = (const uint8_t*) ((uintptr_t) z + n); - uint8_t* o = output; - - if (n >= 16) { - do { - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 16; - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 16; - const __m128i vz = _mm_loadu_si128((const __m128i*) z); - z += 16; - const __m128i vw = _mm_loadu_si128((const __m128i*) w); - w += 16; - const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); - const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw); - const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw); - const __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo); - const __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo); - const __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi); - const __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi); - _mm_storeu_si128((__m128i*) o, vxyzw0); - _mm_storeu_si128((__m128i*) o + 1, vxyzw1); - _mm_storeu_si128((__m128i*) o + 2, vxyzw2); - _mm_storeu_si128((__m128i*) o + 3, vxyzw3); - o = (void*) ((uintptr_t) o + 64); - n -= 16; - } while (n >= 16); - if (n != 0) { - const size_t address_increment = n - 16; - const __m128i vx = _mm_loadu_si128((const __m128i*) ((uintptr_t) x + address_increment)); - const __m128i vy = _mm_loadu_si128((const __m128i*) ((uintptr_t) y + address_increment)); - const __m128i vz = _mm_loadu_si128((const __m128i*) ((uintptr_t) z + address_increment)); - const __m128i vw = _mm_loadu_si128((const __m128i*) ((uintptr_t) w + address_increment)); - const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); - const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw); - const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw); - const __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo); - const __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo); - const __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi); - const __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi); - o = (void*) ((uintptr_t) o + address_increment * 4); - _mm_storeu_si128((__m128i*) o, vxyzw0); - _mm_storeu_si128((__m128i*) o + 1, vxyzw1); - _mm_storeu_si128((__m128i*) o + 2, vxyzw2); - _mm_storeu_si128((__m128i*) o + 3, vxyzw3); - } - } else { 
- do { - const uint8_t vx = *x++; - const uint8_t vy = *y++; - const uint8_t vz = *z++; - const uint8_t vw = *w++; - o[0] = vx; - o[1] = vy; - o[2] = vz; - o[3] = vw; - o += 4; - } while (--n != 0); - } -} diff --git a/src/x8-zip/x8-zip-xm-neon.c b/src/x8-zip/x8-zip-xm-neon.c deleted file mode 100644 index 7839d0826d8c..000000000000 --- a/src/x8-zip/x8-zip-xm-neon.c +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_xm_ukernel__neon( - size_t n, - size_t m, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* w = input; - const size_t input_increment = n * 3; - const size_t output_increment = 4 - m * n; - const uint8_t* last_input = w + n * (m - 1); - uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4)); - - if (n >= 8) { - for (size_t i = 0; i < m; i += 4) { - size_t k = n; - w = (const uint8_t*) ((uintptr_t) w + input_increment); - if (w >= last_input) { - w = last_input; - } - const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n); - const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n); - const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n); - while (k >= 8) { - const uint8x8_t vx = vld1_u8(x); x += 8; - const uint8x8_t vy = vld1_u8(y); y += 8; - const uint8x8_t vz = vld1_u8(z); z += 8; - const uint8x8_t vw = vld1_u8(w); w += 8; - - const uint8x8x2_t vxy = vzip_u8(vx, vy); - const uint8x8x2_t vzw = vzip_u8(vz, vw); - const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0])); - const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1])); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 1); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 1); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 1); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 1); - output = (uint8_t*) ((uintptr_t) output + m); - - k -= 8; - } - if (k != 0) { - const size_t address_increment = k - 8; - x = (const uint8_t*) ((uintptr_t) x + address_increment); - y = (const uint8_t*) ((uintptr_t) y + address_increment); - z = (const uint8_t*) ((uintptr_t) z + address_increment); - w = (const uint8_t*) ((uintptr_t) w + address_increment); - const int64x1_t vshift = vmov_n_s64(8 * address_increment); - - const uint64x1_t vx = vshl_u64(vreinterpret_u64_u8(vld1_u8(x)), vshift); - const uint64x1_t vy = vshl_u64(vreinterpret_u64_u8(vld1_u8(y)), vshift); - const uint64x1_t vz = vshl_u64(vreinterpret_u64_u8(vld1_u8(z)), vshift); - const uint64x1_t vw = vshl_u64(vreinterpret_u64_u8(vld1_u8(w)), vshift); 
w += 8; - const uint8x8x2_t vxy = vzip_u8(vreinterpret_u8_u64(vx), vreinterpret_u8_u64(vy)); - const uint8x8x2_t vzw = vzip_u8(vreinterpret_u8_u64(vz), vreinterpret_u8_u64(vw)); - const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0])); - const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1])); - - uint32x2_t vxyzw0 = vreinterpret_u32_u16(vxyzw_lo.val[0]); - uint32x2_t vxyzw1 = vreinterpret_u32_u16(vxyzw_lo.val[1]); - uint32x2_t vxyzw2 = vreinterpret_u32_u16(vxyzw_hi.val[0]); - uint32x2_t vxyzw3 = vreinterpret_u32_u16(vxyzw_hi.val[1]); - - if (k & 4) { - vst1_lane_u32((void*) output, vxyzw0, 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vxyzw0, 1); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vxyzw1, 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vxyzw1, 1); - output = (uint8_t*) ((uintptr_t) output + m); - - vxyzw0 = vxyzw2; - vxyzw1 = vxyzw3; - } - - if (k & 2) { - vst1_lane_u32((void*) output, vxyzw0, 0); - output = (uint8_t*) ((uintptr_t) output + m); - - vst1_lane_u32((void*) output, vxyzw0, 1); - output = (uint8_t*) ((uintptr_t) output + m); - - vxyzw0 = vxyzw1; - } - if (k & 1) { - vst1_lane_u32((void*) output, vxyzw0, 0); - output = (uint8_t*) ((uintptr_t) output + m); - } - } - output = (uint8_t*) ((uintptr_t) output + output_increment); - if (output > last_output) { - output = last_output; - } - } - } else { - const uint8_t* i = input; - uint8_t* o = output; - size_t k = n; - do { - size_t l = m; - const uint8_t* ii = i++; - do { - *o++ = *ii; - ii += n; - } while (--l != 0); - } while (--k != 0); - } -} diff --git a/src/x8-zip/x8-zip-xm-scalar.c b/src/x8-zip/x8-zip-xm-scalar.c deleted file mode 100644 index 4d2a4e553333..000000000000 --- a/src/x8-zip/x8-zip-xm-scalar.c +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include - -#include "xnnpack/zip.h" - - -void xnn_x8_zip_xm_ukernel__scalar( - size_t n, - size_t m, - const uint8_t* input, - uint8_t* output) -{ - assert(n != 0); - assert(m >= 4); - - size_t k = n; - do { - size_t l = m; - const uint8_t* input_column = input++; - do { - *output++ = *input_column; - input_column = (uint8_t*) ((uintptr_t) input_column + n); - } while (--l != 0); - k -= sizeof(uint8_t); - } while (k != 0); -} diff --git a/src/x8-zip/x8-zip-xm-sse2.c b/src/x8-zip/x8-zip-xm-sse2.c deleted file mode 100644 index 1309639dd59b..000000000000 --- a/src/x8-zip/x8-zip-xm-sse2.c +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include "xnnpack/zip.h" -#include "xnnpack/unaligned.h" - - -void xnn_x8_zip_xm_ukernel__sse2( - size_t n, - size_t m, - const uint8_t* input, - uint8_t* output) -{ - const uint8_t* w = input; - const size_t input_increment = n * 3; - const size_t output_increment = 4 - m * n; - const uint8_t* last_input = w + n * (m - 1); - uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4)); - - if (n >= 8) { - for (size_t i = 0; i < m; i += 4) { - size_t k = n; - w = (const uint8_t*) ((uintptr_t) w + input_increment); - if (w >= last_input) { - w = last_input; - } - const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n); - const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n); - const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n); - while (k >= 16) { - const __m128i vx = _mm_loadu_si128((const __m128i*) x); - x += 16; - const __m128i vy = _mm_loadu_si128((const __m128i*) y); - y += 16; - const __m128i vz = _mm_loadu_si128((const __m128i*) z); - z += 16; - const __m128i vw = _mm_loadu_si128((const __m128i*) w); - w += 16; - const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); - const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); - const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw); - const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw); - __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo); - __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo); - __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi); - __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi); - - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw1 = _mm_unpackhi_epi64(vxyzw1, vxyzw1); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw2)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw2 = _mm_shufflelo_epi16(vxyzw2, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw2)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw2 = _mm_unpackhi_epi64(vxyzw2, vxyzw2); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw2)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw2 = _mm_shufflelo_epi16(vxyzw2, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw2)); - output = (uint8_t*) ((uintptr_t) output + m); - - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw3)); - output = (uint8_t*) ((uintptr_t) output + m); - 
vxyzw3 = _mm_shufflelo_epi16(vxyzw3, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw3)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw3 = _mm_unpackhi_epi64(vxyzw3, vxyzw3); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw3)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw3 = _mm_shufflelo_epi16(vxyzw3, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw3)); - output = (uint8_t*) ((uintptr_t) output + m); - k -= 16; - }; - if (k >= 8) { - const __m128i vx = _mm_loadl_epi64((const __m128i*) x); - x += 8; - const __m128i vy = _mm_loadl_epi64((const __m128i*) y); - y += 8; - const __m128i vz = _mm_loadl_epi64((const __m128i*) z); - z += 8; - const __m128i vw = _mm_loadl_epi64((const __m128i*) w); - w += 8; - const __m128i vxy = _mm_unpacklo_epi8(vx, vy); - const __m128i vzw = _mm_unpacklo_epi8(vz, vw); - __m128i vxyzw0 = _mm_unpacklo_epi16(vxy, vzw); - __m128i vxyzw1 = _mm_unpackhi_epi16(vxy, vzw); - - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw1 = _mm_unpackhi_epi64(vxyzw1, vxyzw1); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw1)); - output = (uint8_t*) ((uintptr_t) output + m); - k -= 8; - } - if (k != 0) { - const size_t address_decrement = 8 - k; - x -= address_decrement; - y -= address_decrement; - z -= address_decrement; - w -= address_decrement; - const __m128i vshift = _mm_cvtsi32_si128((int) address_decrement * 8); - - const __m128i vx = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) x), vshift); - const __m128i vy = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) y), vshift); - const __m128i vz = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) z), vshift); - const __m128i vw = _mm_srl_epi64(_mm_loadl_epi64((const __m128i*) w), vshift); - w += 8; - const __m128i vxy = _mm_unpacklo_epi8(vx, vy); - const __m128i vzw = _mm_unpacklo_epi8(vz, vw); - __m128i vxyzw0 = _mm_unpacklo_epi16(vxy, vzw); - __m128i vxyzw1 = _mm_unpackhi_epi16(vxy, vzw); - - if (k & 4) { - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) 
((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = vxyzw1; - } - - if (k & 2) { - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); - } - if (k & 1) { - unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vxyzw0)); - output = (uint8_t*) ((uintptr_t) output + m); - } - } - output = (uint8_t*) ((uintptr_t) output + output_increment); - if (output > last_output) { - output = last_output; - } - } - } else { - const uint8_t* i = input; - uint8_t* o = output; - size_t k = n; - do { - size_t l = m; - const uint8_t* ii = i++; - do { - *o++ = *ii; - ii += n; - } while (--l != 0); - } while (--k != 0); - } -} diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 835e7d13e2eb..6fde480accd3 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -350,6 +350,11 @@ struct gemm_context { size_t mr_block_size, size_t nr_block_size); + XNN_PRIVATE void xnn_compute_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); + XNN_PRIVATE void xnn_compute_dqgemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], size_t mr_block_start, @@ -378,6 +383,11 @@ struct gemm_context { size_t mr_block_size, size_t nr_block_size); + XNN_PRIVATE void xnn_compute_hmp_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t group_index, size_t mr_block_start, + size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); + XNN_PRIVATE void xnn_compute_hmp_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], uint32_t uarch_index, @@ -1198,29 +1208,6 @@ struct elementwise_binary_context { size_t i, size_t j, size_t k, size_t l, size_t m); #endif -struct channel_shuffle_context { - const void* x; - size_t x_stride; - void* y; - size_t y_stride; - size_t n; - size_t m; - union { - xnn_zipc_ukernel_fn fixed_ukernel; - xnn_zipv_ukernel_fn variable_ukernel; - }; -}; - -#ifndef __cplusplus - XNN_PRIVATE void xnn_compute_channel_shuffle_fixed( - const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t index); - - XNN_PRIVATE void xnn_compute_channel_shuffle_variable( - const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t index); -#endif - struct lut_strided_context { size_t n; const void* x; @@ -1302,9 +1289,8 @@ struct reduce_context { xnn_rdsum_ukernel_fn rdsum; } ukernel; xnn_vunary_ukernel_fn cvt_ukernel; - xnn_vunary_ukernel_fn s32_f32_cvt_ukernel; - xnn_vunary_ukernel_fn u32_f32_cvt_ukernel; struct xnn_reduce_params params; + union xnn_unary_uparams cvt_params; }; #ifndef __cplusplus @@ -1438,10 +1424,19 @@ struct f32_qd8_convert_context { const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index); + XNN_PRIVATE void xnn_compute_f16_qdu8_convert( + const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); + XNN_PRIVATE void xnn_compute_f32_qd8_convert( const struct 
f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index); + XNN_PRIVATE void xnn_compute_f32_qdu8_convert( + const struct f32_qd8_convert_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); + XNN_PRIVATE void xnn_compute_pad_qd8_params( const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], size_t batch_index); @@ -1471,6 +1466,7 @@ struct f32_qd8_convert_context { size_t mr; size_t kr; size_t sr; + size_t group_stride; const float* XNN_RESTRICT lhs; size_t lhs_stride; int8_t* XNN_RESTRICT lhs_packed; @@ -1481,7 +1477,7 @@ struct f32_qd8_convert_context { XNN_PRIVATE void xnn_compute_f32_qp8_convert( const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t m_idx_start); + size_t group_idx, size_t m_idx_start, size_t m_tile); #endif struct u8_softmax_context { diff --git a/src/xnnpack/config-types.h b/src/xnnpack/config-types.h index 79506ca06bbb..7262d376ef37 100644 --- a/src/xnnpack/config-types.h +++ b/src/xnnpack/config-types.h @@ -9,6 +9,7 @@ #include #include +#include "xnnpack/hardware-config.h" #include "xnnpack/microfnptr.h" #ifdef __cplusplus @@ -203,6 +204,7 @@ struct xnn_gemm_config { uint8_t log2_sr; uint8_t planes; // number of 4 bit planes (1 for legacy, 2 for unzip) uint8_t mr_packed; // `mr` value used for packed left-hand operands. + enum xnn_arch_flags arch; }; struct xnn_maxpool_config { diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 597b03b1b916..0fd3d1bf933a 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -251,7 +251,14 @@ XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f32_qb4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f32_qc4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f32_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qp8_f32_qc4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qp8_f32_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qp8_f32_qb4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f32_qc4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f16_qc8w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f32_qc8w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f32_qb4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f16_qc4w_gemm_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f16_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qu8_gemm_config(); @@ -260,9 +267,6 @@ XNN_INTERNAL const struct xnn_maxpool_config* xnn_init_f32_maxpool_config(); XNN_INTERNAL const struct xnn_maxpool_config* xnn_init_s8_maxpool_config(); XNN_INTERNAL const struct xnn_maxpool_config* xnn_init_u8_maxpool_config(); -XNN_INTERNAL const struct xnn_zip_config* xnn_init_x8_zip_config(); -XNN_INTERNAL const struct xnn_zip_config* xnn_init_x32_zip_config(); - XNN_INTERNAL const struct xnn_rmax_config* xnn_init_f16_rmax_config(); XNN_INTERNAL const struct xnn_rmax_config* xnn_init_f32_rmax_config(); XNN_INTERNAL const struct xnn_rmax_config* xnn_init_u8_rmax_config(); diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index e88a349d250b..fe2017be2d01 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -174,6 +174,8 @@ 
DECLARE_F16_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f16_gemm_minmax_ukernel_8x64__avx51 size_t cn_stride, \ const union xnn_f32_minmax_params params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +DECLARE_PF32_GEMM_MINMAX_UKERNEL_FUNCTION( + xnn_pf32_gemm_minmax_ukernel_1x32__neonsme2) DECLARE_PF32_GEMM_MINMAX_UKERNEL_FUNCTION( xnn_pf32_gemm_minmax_ukernel_32x32__neonsme2) @@ -340,6 +342,46 @@ DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x8__avx_br DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x16__avx_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x8__avx_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane) + +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast) + +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast) 
+DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast) + DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x8__fma3_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16__fma3_broadcast) @@ -2317,6 +2359,24 @@ DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_u DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2) DECLARE_QP8_F32_QC4W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x8c16s2__neoni8mm_mstep2) +#define DECLARE_QP8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ + XNN_INTERNAL void fn_name( \ + size_t m, \ + size_t n, \ + size_t k, \ + const void* lhs_packed, \ + const void* rhs_packed, \ + float* dst, \ + size_t dst_stride_row, \ + size_t dst_stride_col, \ + union xnn_f32_minmax_params \ + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); + +DECLARE_QP8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4) +DECLARE_QP8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4) +DECLARE_QP8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot) +DECLARE_QP8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot) + #define DECLARE_QP8_F32_QB4W_GEMM_MINMAX_UKERNEL_FUNCTION(fn_name) \ XNN_INTERNAL void fn_name( \ size_t m, \ @@ -2466,6 +2526,44 @@ DECLARE_QD8_F16_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f16_qc8w_gemm_minmax_u const union xnn_f32_minmax_params params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)], \ const struct xnn_qd8_quantization_params quantization_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni) 
+DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni) + +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni) + +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane) +DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane) + 
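The new assembly microkernel declarations above all follow the qd8_f32_qc8w contract: int8 activations that were dynamically quantized per row, int8 weights quantized per output channel, and an fp32 output with a min/max clamp. A rough scalar reference of that math (the real kernels consume packed weights and a quantization-params struct; the layouts and names below are deliberately simplified):

#include <math.h>
#include <stddef.h>
#include <stdint.h>

static void qd8_f32_qc8w_gemm_reference(
    size_t m, size_t n, size_t k,
    const int8_t* a, const int32_t* a_zero_point, const float* a_scale,
    const int8_t* w, const float* w_scale, const float* bias,
    float* c, float output_min, float output_max) {
  for (size_t i = 0; i < m; i++) {
    for (size_t j = 0; j < n; j++) {
      int32_t acc = 0;
      for (size_t p = 0; p < k; p++) {
        /* Subtract the per-row activation zero point before accumulating. */
        acc += ((int32_t) a[i * k + p] - a_zero_point[i]) *
               (int32_t) w[j * k + p];
      }
      /* Rescale with the per-row and per-channel scales, add bias, clamp. */
      const float y = (float) acc * a_scale[i] * w_scale[j] + bias[j];
      c[i * n + j] = fminf(fmaxf(y, output_min), output_max);
    }
  }
}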
DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55) DECLARE_QD8_F32_QC8W_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_cortex_a55) diff --git a/src/xnnpack/hardware-config.h b/src/xnnpack/hardware-config.h index dee719c44536..91bb19965ed0 100644 --- a/src/xnnpack/hardware-config.h +++ b/src/xnnpack/hardware-config.h @@ -15,17 +15,18 @@ extern "C" { #endif +// These flags should be sorted by preference (a < b ==> a slower than b). enum xnn_arch_flags { #if XNN_ARCH_ARM || XNN_ARCH_ARM64 xnn_arch_arm_v6 = 1 << 0, xnn_arch_arm_vfpv2 = 1 << 1, xnn_arch_arm_vfpv3 = 1 << 2, xnn_arch_arm_neon = 1 << 3, - xnn_arch_arm_neon_fp16 = 1 << 4, - xnn_arch_arm_neon_fma = 1 << 5, - xnn_arch_arm_neon_v8 = 1 << 6, - xnn_arch_arm_fp16_arith = 1 << 7, - xnn_arch_arm_neon_fp16_arith = 1 << 8, + xnn_arch_arm_neon_fma = 1 << 4, + xnn_arch_arm_neon_v8 = 1 << 5, + xnn_arch_arm_fp16_arith = 1 << 6, + xnn_arch_arm_neon_fp16_arith = 1 << 7, + xnn_arch_arm_neon_fp16 = 1 << 8, xnn_arch_arm_neon_bf16 = 1 << 9, xnn_arch_arm_neon_dot = 1 << 10, xnn_arch_arm_neon_i8mm = 1 << 11, @@ -41,16 +42,16 @@ enum xnn_arch_flags { xnn_arch_x86_f16c = 1 << 3, xnn_arch_x86_fma3 = 1 << 4, xnn_arch_x86_avx2 = 1 << 5, - xnn_arch_x86_avxvnni = 1 << 6, - xnn_arch_x86_avxvnniint8 = 1 << 7, - xnn_arch_x86_avx256skx = 1 << 8, - xnn_arch_x86_avx256vnni = 1 << 9, - xnn_arch_x86_avx256vnnigfni = 1 << 10, - xnn_arch_x86_avx512f = 1 << 11, - xnn_arch_x86_avx512vbmi = 1 << 12, - xnn_arch_x86_avx512skx = 1 << 13, - xnn_arch_x86_avx512vnni = 1 << 14, - xnn_arch_x86_avx512vnnigfni = 1 << 15, + xnn_arch_x86_avx512f = 1 << 6, + xnn_arch_x86_avx512vbmi = 1 << 7, + xnn_arch_x86_avx512skx = 1 << 8, + xnn_arch_x86_avx512vnni = 1 << 9, + xnn_arch_x86_avx512vnnigfni = 1 << 10, + xnn_arch_x86_avxvnni = 1 << 11, + xnn_arch_x86_avxvnniint8 = 1 << 12, + xnn_arch_x86_avx256skx = 1 << 13, + xnn_arch_x86_avx256vnni = 1 << 14, + xnn_arch_x86_avx256vnnigfni = 1 << 15, xnn_arch_x86_avx512amx = 1 << 16, xnn_arch_x86_avx512fp16 = 1 << 17, #endif @@ -146,6 +147,17 @@ struct xnn_hardware_config { #if XNN_ARCH_HEXAGON bool use_hvx; #endif // XNN_ARCH_HEXAGON + // Size in bytes of the L1 data cache. + size_t l1_data_cache_bytes; + size_t l1_data_cache_line_size; + size_t l1_data_cache_associativity; + size_t l1_data_cache_num_sets; + + // Size in bytes of the L2 data cache. 
+ size_t l2_data_cache_bytes; + size_t l2_data_cache_line_size; + size_t l2_data_cache_associativity; + size_t l2_data_cache_num_sets; }; XNN_INTERNAL const struct xnn_hardware_config* xnn_init_hardware_config(); diff --git a/src/xnnpack/indirection.h b/src/xnnpack/indirection.h index e5e71f852bb8..a7276b2abbaf 100644 --- a/src/xnnpack/indirection.h +++ b/src/xnnpack/indirection.h @@ -114,10 +114,23 @@ XNN_INTERNAL void xnn_indirection_init_subconv2d( uint32_t log2_element_size); XNN_INTERNAL void xnn_indirection_init_maxpool2d( - xnn_operator_t op, - size_t step_height, - size_t step_width, - uint32_t log2_element_size); + const void** indirection_buffer, + const void* input, + const size_t input_pixel_stride, + const size_t input_height, + const size_t input_width, + const size_t output_height, + const size_t output_width, + const size_t kernel_height, + const size_t kernel_width, + const size_t stride_height, + const size_t stride_width, + const size_t dilation_height, + const size_t dilation_width, + const size_t input_padding_top, + const size_t input_padding_left, + const size_t step_height, + const size_t step_width); XNN_INTERNAL void xnn_indirection_init_resize_bilinear2d_hwc_f16( size_t output_y_start, diff --git a/src/xnnpack/internal.h b/src/xnnpack/internal.h index 954dcc68946a..f0d745f8e734 100644 --- a/src/xnnpack/internal.h +++ b/src/xnnpack/internal.h @@ -12,16 +12,13 @@ #include #include "xnnpack.h" +#include "xnnpack/config-types.h" #include "pthreadpool.h" #ifdef __cplusplus extern "C" { #endif -/// If set, try to pack the quantized values for use by a GEMM. -#define XNN_FLAG_MAYBE_PACK_FOR_GEMM 0x00000080 -#define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100 - enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( size_t input_channels, // size_t output_channels, // @@ -38,20 +35,71 @@ enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc4w( xnn_weights_cache_t weights_cache, // xnn_operator_t* fully_connected_op_out); +enum xnn_status xnn_create_fully_connected_nc_qp8_f32_qc8w( + size_t input_channels, // + size_t output_channels, // + size_t input_stride, // + size_t output_stride, // + const float* kernel_scale, // + const void* kernel, // + const float* bias, // + float output_min, // + float output_max, // + uint32_t flags, // + xnn_code_cache_t code_cache, // + xnn_weights_cache_t weights_cache, // + xnn_operator_t* fully_connected_op_out); + enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc4w( xnn_operator_t fully_connected_op, // const int8_t* input, // float* output); +enum xnn_status xnn_setup_fully_connected_nc_qp8_f32_qc8w( + xnn_operator_t fully_connected_op, // + const int8_t* input, // + float* output); + enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qc4w( xnn_operator_t fully_connected_op, // size_t batch_size, // pthreadpool_t threadpool); -enum xnn_status xnn_create_convert_nc_f32_qp8(uint32_t flags, // - xnn_operator_t* convert_op_out); +enum xnn_status xnn_reshape_fully_connected_nc_qp8_f32_qc8w( + xnn_operator_t fully_connected_op, // + size_t batch_size, // + pthreadpool_t threadpool); + +enum xnn_status xnn_create_batch_matrix_multiply_nc_qp8_f32_qc8w( + size_t batch_size_b, // + size_t k, // + size_t n, // + const int8_t* data_b, // + const float* scale_b, // + uint32_t flags, xnn_operator_t* batch_matrix_multiply_op_out); + +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qp8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, // + size_t num_batch_dims, // + const size_t* batch_dims_a, // + const size_t* 
batch_dims_b, // + size_t m, // + size_t k, // + size_t n, // + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_qp8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, // + const int8_t* input_a, // + float* output); + +enum xnn_status xnn_create_convert_nc_f32_qp8( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + xnn_operator_t* convert_op_out); enum xnn_status xnn_reshape_convert_nc_f32_qp8(xnn_operator_t convert_op, // + size_t num_groups, // size_t batch_size, // size_t channels, // size_t input_stride, // @@ -109,6 +157,18 @@ enum xnn_status xnn_create_fully_connected_nc_pf32( xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out); +enum xnn_status xnn_create_convolution2d_nchw_f32_f16( + uint32_t input_padding_top, uint32_t input_padding_right, + uint32_t input_padding_bottom, uint32_t input_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height, + uint32_t subsampling_width, uint32_t dilation_height, + uint32_t dilation_width, uint32_t groups, size_t group_input_channels, + size_t group_output_channels, size_t input_channel_stride, + size_t output_channel_stride, const void* kernel, const void* bias, + float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + enum xnn_status xnn_create_convolution2d_nhwc_pf32( uint32_t input_padding_top, uint32_t input_padding_right, uint32_t input_padding_bottom, uint32_t input_padding_left, @@ -121,6 +181,188 @@ enum xnn_status xnn_create_convolution2d_nhwc_pf32( xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out); +// quantization_params must be padded with at least +// XNN_EXTRA_QUANTIZATION_PARAMS entries. +enum xnn_status xnn_setup_convert_nc_f16_qdu8( + xnn_operator_t convert_op, const void* input, uint8_t* output, + struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_convert_nc_f16_qdu8(uint32_t flags, + xnn_operator_t* convert_op_out); + +enum xnn_status xnn_reshape_convert_nc_f16_qdu8( + xnn_operator_t convert_op, size_t batch_size, size_t channels, + size_t input_stride, size_t output_stride, pthreadpool_t threadpool); + +enum xnn_status xnn_create_convert_nc_f32_qdu8(uint32_t flags, + xnn_operator_t* convert_op_out); + +enum xnn_status xnn_reshape_convert_nc_f32_qdu8( + xnn_operator_t convert_op, size_t batch_size, size_t channels, + size_t input_stride, size_t output_stride, pthreadpool_t threadpool); + +// quantization_params must be padded with at least +// XNN_EXTRA_QUANTIZATION_PARAMS entries. 
+enum xnn_status xnn_setup_convert_nc_f32_qdu8( + xnn_operator_t convert_op, const float* input, uint8_t* output, + struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f16_qc8w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, const float* kernel_scale, const int8_t* kernel, + const float* bias, float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f16_qc8w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f16_qc8w( + xnn_operator_t fully_connected_op, const int8_t* input, float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qc8w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, const float* kernel_scale, const int8_t* kernel, + const float* bias, float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f32_qc8w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f32_qc8w( + xnn_operator_t fully_connected_op, const int8_t* input, float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_convolution2d_nhwc_qdu8_f32_qc8w( + uint32_t input_padding_top, uint32_t input_padding_right, + uint32_t input_padding_bottom, uint32_t input_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height, + uint32_t subsampling_width, uint32_t dilation_height, + uint32_t dilation_width, uint32_t groups, size_t group_input_channels, + size_t group_output_channels, size_t input_channel_stride, + size_t output_channel_stride, const float* kernel_scale, + const int8_t* kernel, const float* bias, float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t convolution_op, size_t batch_size, size_t input_height, + size_t input_width, size_t* workspace_size, size_t* workspace_alignment, + size_t* output_height_out, size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t convolution_op, void* workspace, const uint8_t* input, + float* output, const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_convolution2d_nhwc_qdu8_f16_qc8w( + uint32_t input_padding_top, uint32_t input_padding_right, + uint32_t input_padding_bottom, uint32_t input_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height, + uint32_t subsampling_width, uint32_t dilation_height, + uint32_t dilation_width, uint32_t groups, size_t group_input_channels, + size_t group_output_channels, size_t input_channel_stride, + size_t output_channel_stride, const float* kernel_scale, + const int8_t* kernel, const float* bias, float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, 
xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qdu8_f16_qc8w( + xnn_operator_t convolution_op, size_t batch_size, size_t input_height, + size_t input_width, size_t* workspace_size, size_t* workspace_alignment, + size_t* output_height_out, size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_qdu8_f16_qc8w( + xnn_operator_t convolution_op, void* workspace, const int8_t* input, + void* output, const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qc4w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, uint8_t kernel_zero_point, const float* kernel_scale, + const void* kernel, const float* bias, float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f32_qc4w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f32_qc4w( + xnn_operator_t fully_connected_op, const uint8_t* input, float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_deconvolution2d_nhwc_qdu8_f32_qc8w( + uint32_t output_padding_top, uint32_t output_padding_right, + uint32_t output_padding_bottom, uint32_t output_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t stride_height, + uint32_t stride_width, uint32_t dilation_height, uint32_t dilation_width, + uint32_t groups, size_t group_input_channels, size_t group_output_channels, + size_t input_pixel_stride, size_t output_pixel_stride, + const float* kernel_scale, const int8_t* kernel, const float* bias, + float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t deconvolution_op, size_t batch_size, size_t input_height, + size_t input_width, uint32_t adjustment_height, uint32_t adjustment_width, + size_t* output_height_out, size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_qdu8_f32_qc8w( + xnn_operator_t deconvolution_op, const int8_t* input, float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qb4w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, size_t block_size, uint8_t kernel_zero_point, + const uint16_t* kernel_scale, const void* kernel, const float* bias, + float output_min, float output_max, uint32_t flags, + xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f32_qb4w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f32_qb4w( + xnn_operator_t fully_connected_op, const int8_t* input, float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qdu8_f16_qc4w( + size_t input_channels, size_t output_channels, size_t input_stride, + size_t output_stride, uint8_t kernel_zero_point, const float* kernel_scale, + const void* kernel, const float* bias, 
float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_setup_fully_connected_nc_qdu8_f16_qc4w( + xnn_operator_t fully_connected_op, const int8_t* input, void* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_reshape_fully_connected_nc_qdu8_f16_qc4w( + xnn_operator_t fully_connected_op, size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_batch_matrix_multiply_nc_qdu8_f32_qc8w( + size_t batch_size_b, size_t k, size_t n, const int8_t* data_b, + const float* scale_b, uint32_t flags, + xnn_operator_t* batch_matrix_multiply_op); + +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qdu8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims, + const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k, + size_t n, pthreadpool_t threadpool); + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_qdu8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, const int8_t* input_a, + const struct xnn_quantization_params* quantization_params, float* output); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/xnnpack/math.h b/src/xnnpack/math.h index d1dc40241033..24a66e88d320 100644 --- a/src/xnnpack/math.h +++ b/src/xnnpack/math.h @@ -510,9 +510,13 @@ XNN_INLINE static uint16_t math_cvt_bf16_fp32(float x) { #define XNN_HAVE_FLOAT16 1 #endif +#ifndef XNN_HAVE_FLOAT16 +#define XNN_HAVE_FLOAT16 0 +#endif + #endif // XNN_HAVE_FLOAT16 -#ifdef XNN_HAVE_FLOAT16 +#if XNN_HAVE_FLOAT16 typedef _Float16 xnn_float16; #else // We want float16s to be a distinct type from uint16_t, to avoid accidental @@ -550,7 +554,7 @@ extern "C" { #endif XNN_INLINE static xnn_float16 xnn_float16_from_float(float f) { -#ifdef XNN_HAVE_FLOAT16 +#if XNN_HAVE_FLOAT16 return (xnn_float16) f; #else struct xnn_float16 result; @@ -560,7 +564,7 @@ XNN_INLINE static xnn_float16 xnn_float16_from_float(float f) { } XNN_INLINE static float xnn_float16_to_float(xnn_float16 fp16) { -#ifdef XNN_HAVE_FLOAT16 +#if XNN_HAVE_FLOAT16 return (float) fp16; #else return fp16_ieee_to_fp32_value(fp16.value); @@ -602,7 +606,7 @@ XNN_INLINE static xnn_bfloat16 xnn_bfloat16_from_bits(uint16_t x) { } XNN_INLINE static xnn_float16 xnn_float16_zero() { -#ifdef XNN_HAVE_FLOAT16 +#if XNN_HAVE_FLOAT16 return (xnn_float16) 0.0f; #else struct xnn_float16 result; diff --git a/src/xnnpack/microfnptr.h b/src/xnnpack/microfnptr.h index edd1a8aae033..82e52c4a2473 100644 --- a/src/xnnpack/microfnptr.h +++ b/src/xnnpack/microfnptr.h @@ -340,6 +340,18 @@ typedef void (*xnn_qp8_f32_qc4w_gemm_minmax_ukernel_fn)( union xnn_f32_minmax_params minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); +typedef void (*xnn_qp8_f32_qc8w_gemm_minmax_ukernel_fn)( + size_t m, + size_t n, + size_t k, + const void* lhs_packed, + const void* rhs_packed, + float* dst, + size_t dst_stride_row, + size_t dst_stride_col, + union xnn_f32_minmax_params + minmax_params[XNN_RESTRICT XNN_MIN_ELEMENTS(1)]); + typedef void (*xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn)( size_t m, size_t n, @@ -2031,8 +2043,7 @@ typedef size_t (*xnn_init_reduce_params_fn)( typedef size_t (*xnn_update_reduce_params_fn)( struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)], - float scale, - int32_t num_elements); + float scale); typedef size_t (*xnn_init_qs8_qc8w_conv_minmax_params_fn)( union xnn_qs8_qc8w_conv_minmax_params params[XNN_MIN_ELEMENTS(1)], @@ -2301,8 +2312,11 @@ struct 
xnn_hmp_qp8gemm_bl_ukernel { }; // Largest GEMM/IGEMM MR used in init.c is 16 (x86 AVX512AMX). -// Largest GEMM/IGEMM MR is 8 in e2e benchmarks. +#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI #define XNN_MAX_MR 32 +#else +#define XNN_MAX_MR 16 +#endif struct gemm_fused_ukernels { union { diff --git a/src/xnnpack/microparams-init.h b/src/xnnpack/microparams-init.h index 2a30be88e113..f9f9c7523093 100644 --- a/src/xnnpack/microparams-init.h +++ b/src/xnnpack/microparams-init.h @@ -163,8 +163,7 @@ DECLARE_INIT_REDUCE_PARAMS_FUNCTION(xnn_init_qu8_reduce_scalar_params); #define DECLARE_UPDATE_REDUCE_PARAMS_FUNCTION(fn_name) \ XNN_INTERNAL size_t fn_name( \ struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)], \ - float scale, \ - int32_t num_elements); + float scale); DECLARE_UPDATE_REDUCE_PARAMS_FUNCTION(xnn_update_f32_reduce_scalar_params); DECLARE_UPDATE_REDUCE_PARAMS_FUNCTION(xnn_update_qs8_reduce_scalar_params); @@ -271,7 +270,6 @@ DECLARE_INIT_UNARY_MICROPARAMS_FUNCTION(xnn_init_qs8_f16_cvt_scalar_params); DECLARE_INIT_UNARY_MICROPARAMS_FUNCTION(xnn_init_qs8_f32_cvt_scalar_params); DECLARE_INIT_UNARY_MICROPARAMS_FUNCTION(xnn_init_qu8_cvt_scalar_params); DECLARE_INIT_UNARY_MICROPARAMS_FUNCTION(xnn_init_qu8_f32_cvt_scalar_params); -DECLARE_INIT_UNARY_MICROPARAMS_FUNCTION(xnn_init_s32_f32_cvt_scalar_params); XNN_INTERNAL size_t xnn_init_qs8_add_minmax_scalar_params( struct xnn_qs8_add_minmax_params uparams[XNN_MIN_ELEMENTS(1)], diff --git a/src/xnnpack/microparams.h b/src/xnnpack/microparams.h index 311489b77a82..11c6069ca4f4 100644 --- a/src/xnnpack/microparams.h +++ b/src/xnnpack/microparams.h @@ -408,7 +408,6 @@ struct xnn_f32_reduce_params { }; struct xnn_qs8_reduce_params { - int32_t num_elements; float scale; float input_output_scale; int8_t input_zero_point; @@ -416,7 +415,6 @@ struct xnn_qs8_reduce_params { }; struct xnn_qu8_reduce_params { - int32_t num_elements; float scale; float input_output_scale; uint8_t input_zero_point; @@ -474,12 +472,6 @@ struct xnn_f32_qu8_cvt_params { } scalar; }; -struct xnn_s32_f32_cvt_params { - struct { - int32_t zero_point; - } scalar; -}; - struct xnn_qs8_cvt_params { struct { int16_t input_zero_point; @@ -578,15 +570,20 @@ struct xnn_qs8_qc4w_packing_params { uint8_t kernel_zero_point; }; +struct xnn_qs8_qc8w_packing_params { + int8_t input_zero_point; + float scale_multiplier; +}; + struct xnn_x32_packb_params { char _; // Dummy member variable to comply with the C standard }; struct xnn_unary_reference_params { float x_scale; - int32_t x_zero_point; + float x_zero_point; float inv_y_scale; - int32_t y_zero_point; + float y_zero_point; union xnn_unary_params params; }; @@ -597,7 +594,6 @@ union xnn_unary_uparams { struct xnn_qs8_f32_cvt_params qs8_f32_cvt; struct xnn_qu8_f32_cvt_params qu8_f32_cvt; struct xnn_qs8_f16_cvt_params qs8_f16_cvt; - struct xnn_s32_f32_cvt_params s32_f32_cvt; struct xnn_qs8_cvt_params qs8_cvt; struct xnn_qu8_cvt_params qu8_cvt; struct xnn_f16_elu_params f16_elu; diff --git a/src/xnnpack/operator-type-defs.h b/src/xnnpack/operator-type-defs.h index 6e72e2305f98..ebc60a45b82b 100644 --- a/src/xnnpack/operator-type-defs.h +++ b/src/xnnpack/operator-type-defs.h @@ -16,20 +16,27 @@ XNN_ENUM_ITEM(xnn_operator_type_average_pooling_nhwc_qu8, "Average Pooling (NHWC XNN_ENUM_ITEM(xnn_operator_type_batch_matrix_multiply_nc_f16, "Batch Matrix Multiply (NC, F16)") XNN_ENUM_ITEM(xnn_operator_type_batch_matrix_multiply_nc_f32, "Batch Matrix Multiply (NC, F32)") XNN_ENUM_ITEM(xnn_operator_type_batch_matrix_multiply_nc_qd8_f32_qc8w, "Batch 
Matrix Multiply (NC, QD8, F32, QC8W)")
+XNN_ENUM_ITEM(xnn_operator_type_batch_matrix_multiply_nc_qdu8_f32_qc8w, "Batch Matrix Multiply (NC, QDU8, F32, QC8W)")
+XNN_ENUM_ITEM(xnn_operator_type_batch_matrix_multiply_nc_qp8_f32_qc8w,
+              "Batch Matrix Multiply (NC, QP8, F32, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_binary_elementwise, "Binary Elementwise (ND)")
-XNN_ENUM_ITEM(xnn_operator_type_channel_shuffle_nc_x8, "Channel Shuffle (NC, X8)")
-XNN_ENUM_ITEM(xnn_operator_type_channel_shuffle_nc_x32, "Channel Shuffle (NC, X32)")
 XNN_ENUM_ITEM(xnn_operator_type_constant_pad_nd_x8, "Constant Pad (ND, X8)")
 XNN_ENUM_ITEM(xnn_operator_type_constant_pad_nd_x16, "Constant Pad (ND, X16)")
 XNN_ENUM_ITEM(xnn_operator_type_constant_pad_nd_x32, "Constant Pad (ND, X32)")
 XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f16_qd8, "Convert (NC, F16, QD8)")
+XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f16_qdu8, "Convert (NC, F16, QDU8)")
 XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f32_qd8, "Convert (NC, F32, QD8)")
+XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f32_qdu8, "Convert (NC, F32, QDU8)")
 XNN_ENUM_ITEM(xnn_operator_type_convert_nc_f32_qp8, "Convert (NC, F32, QP8)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nchw_f16, "Convolution (NCHW, F16)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nchw_f32, "Convolution (NCHW, F32)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_f16, "Convolution (NHWC, F16)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_f32, "Convolution (NHWC, F32)")
+XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_qdu8_f16_qc8w,
+              "Convolution (NHWC, QDU8, F16, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_qd8_f16_qc8w, "Convolution (NHWC, QD8, F16, QC8W)")
+XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_qdu8_f32_qc8w,
+              "Convolution (NHWC, QDU8, F32, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_qd8_f32_qc8w, "Convolution (NHWC, QD8, F32, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_qc8, "Convolution (NHWC, QC8)")
 XNN_ENUM_ITEM(xnn_operator_type_convolution_nhwc_qs8, "Convolution (NHWC, QS8)")
@@ -40,6 +47,8 @@ XNN_ENUM_ITEM(xnn_operator_type_copy_nc_x32, "Copy (NC, X32)")
 XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_f16, "Deconvolution (NHWC, F16)")
 XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_f32, "Deconvolution (NHWC, F32)")
 XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_qd8_f32_qc8w, "Deconvolution (NHWC, QD8, F32, QC8W)")
+XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_qdu8_f32_qc8w,
+              "Deconvolution (NHWC, QDU8, F32, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_qs8, "Deconvolution (NHWC, QS8)")
 XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_qs8_qc8w, "Deconvolution (NC, QS8, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_deconvolution_nhwc_qu8, "Deconvolution (NHWC, QU8)")
@@ -58,11 +67,23 @@ XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_pf32, "Fully Connected (NC, PF32)")
 XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f16_qb4w, "Fully Connected (NC, QD8, F16, QB4W)")
 XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f16_qc4w, "Fully Connected (NC, QD8, F16, QC4W)")
+XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qdu8_f16_qc4w,
+              "Fully Connected (NC, QDU8, F16, QC4W)")
 XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f16_qc8w, "Fully Connected (NC, QD8, F16, QC8W)")
+XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qdu8_f16_qc8w,
+              "Fully Connected (NC, QDU8, F16, QC8W)")
 XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f32_qb4w, "Fully Connected (NC, QD8, F32, QB4W)")
+XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qdu8_f32_qb4w, + "Fully Connected (NC, QDU8, F32, QB4W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f32_qc4w, "Fully Connected (NC, QD8, F32, QC4W)") +XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qdu8_f32_qc4w, + "Fully Connected (NC, QDU8, F32, QC4W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qd8_f32_qc8w, "Fully Connected (NC, QD8, F32, QC8W)") +XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qdu8_f32_qc8w, + "Fully Connected (NC, QDU8, F32, QC8W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qp8_f32_qc4w, "Fully Connected (NC, QP8, F32, QC4W)") +XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qp8_f32_qc8w, + "Fully Connected (NC, QP8, F32, QC8W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qp8_f32_qb4w, "Fully Connected (NC, QP8, F32, QB4W)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qs8, "Fully Connected (NC, QS8)") XNN_ENUM_ITEM(xnn_operator_type_fully_connected_nc_qs8_qc8w, "Fully Connected (NC, QS8, QC8W)") diff --git a/src/xnnpack/operator-utils.h b/src/xnnpack/operator-utils.h index 7387f32fa11d..565c6cd12a25 100644 --- a/src/xnnpack/operator-utils.h +++ b/src/xnnpack/operator-utils.h @@ -12,18 +12,18 @@ #include "xnnpack/operator.h" #include "xnnpack/params.h" +static inline bool use_weights_cache(struct xnn_operator* op) { + return op->weights_cache != NULL; +} + static inline void* packed_weights(struct xnn_operator* op) { - if (op->weights_cache == NULL) { - return op->packed_weights.pointer; - } else { + if (use_weights_cache(op)) { return op->weights_cache->offset_to_addr(op->weights_cache->context, op->packed_weights.offset); + } else { + return op->packed_weights.pointer; } } -static inline bool use_weights_cache(struct xnn_operator* op) { - return op->weights_cache != NULL; -} - // Get a pointer to a region to pack weights into. If weights cache is available, use it, returning to a pointer to the // cache's buffer, otherwise, allocate and return a pointer to a new region. Returns NULL on error. XNN_INTERNAL void* xnn_get_pointer_to_write_weights( diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h index 6dd20664472d..53fdbd3a66d9 100644 --- a/src/xnnpack/operator.h +++ b/src/xnnpack/operator.h @@ -63,6 +63,7 @@ struct xnn_ukernel_gemm { xnn_packw_gemm_goi_ukernel_fn packw_gemm_goi; xnn_packw_gemm_gio_ukernel_fn packw_gemm_gio; uint8_t mr; + uint8_t mr_packed; uint8_t nr; uint8_t kr; uint8_t sr; @@ -255,6 +256,7 @@ struct xnn_operator { // but params need to be swapped for commutative ops with per-operand params. union { union xnn_binary_uparams binary; + union xnn_unary_uparams unary; struct xnn_f16_default_params f16_default; union xnn_f32_minmax_params f32_minmax; struct xnn_f32_default_params f32_default; @@ -281,8 +283,6 @@ struct xnn_operator { const struct xnn_reduce_config* rdsum_config; const struct xnn_reduce_config* rsum_config; const struct xnn_unary_elementwise_config* cvt_config; - const struct xnn_unary_elementwise_config* s32_f32_cvt_config; - const struct xnn_unary_elementwise_config* u32_f32_cvt_config; }; const struct xnn_ibilinear_chw_config* ibilinear_chw_config; const struct xnn_ibilinear_config* ibilinear_config; @@ -311,7 +311,10 @@ struct xnn_operator { const struct xnn_binary_elementwise_config* binary_elementwise_config; struct { const struct xnn_unary_elementwise_config* unary_elementwise_config; - const struct xnn_reduce_config* rminmax_config; // For dynamic quantization convert operator. 
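The operator-utils.h hunk above reorders the two helpers so that packed_weights() can be written in terms of use_weights_cache(). A self-contained sketch of the resulting pattern, with the cache type reduced to a plain base pointer (the real offset_to_addr callback is simplified to a pointer add here):

#include <stdbool.h>
#include <stddef.h>

struct weights_cache { char* base; };

struct op {
  struct weights_cache* weights_cache;
  union { void* pointer; size_t offset; } packed_weights;
};

static inline bool use_weights_cache(struct op* op) {
  return op->weights_cache != NULL;
}

static inline void* packed_weights(struct op* op) {
  /* With a cache attached, packed_weights.offset is an offset into the
     cache's buffer; otherwise packed_weights.pointer is an owned buffer. */
  return use_weights_cache(op)
             ? (void*) (op->weights_cache->base + op->packed_weights.offset)
             : op->packed_weights.pointer;
}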
+ const struct xnn_reduce_config* + rminmax_config; // For dynamic quantization convert operator. + const struct xnn_gemm_config* + gemm_config; // For dynamic quantization convert operator. }; // For unary elementwise operators. struct { const struct xnn_rmax_config* rmax_config; @@ -329,7 +332,6 @@ struct xnn_operator { union { struct argmax_pooling_context argmax_pooling; struct average_pooling_context average_pooling; - struct channel_shuffle_context channel_shuffle; struct conv2d_context conv2d; struct dwconv2d_context dwconv2d; struct { diff --git a/src/xnnpack/pack.h b/src/xnnpack/pack.h index 8850d370249a..f48120e2f31f 100644 --- a/src/xnnpack/pack.h +++ b/src/xnnpack/pack.h @@ -479,6 +479,30 @@ XNN_INTERNAL size_t xnn_packed_stride_kai_qs4_weights_and_biases( size_t k_stride, // size_t extra_bytes); +XNN_INTERNAL void xnn_pack_kai_qs8_weights_and_biases( + uint32_t flags, // + const struct xnn_gemm_config* gemm_config, // + size_t input_channels, // + size_t output_channels, // + size_t groups, // + size_t k_stride, // + const void* accumulator_init, // + const void* weights, // + xnn_init_scale_params_fn init_extra_data0_fn, // + const void* extra_data0, // + size_t extra_data0_element_size, // + xnn_init_scale_params_fn init_extra_data1_fn, // + const void* extra_data1, // + size_t extra_data1_element_size, // + void* packed_weights_ptr, // + const void* params); + +XNN_INTERNAL size_t xnn_packed_stride_kai_qs8_weights_and_biases( + const struct xnn_gemm_config* gemm_config, // + size_t k, // + size_t k_stride, // + size_t extra_bytes); + size_t xnn_packed_stride_kai_f32_weights_and_biases( const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride, size_t extra_bytes); diff --git a/src/xnnpack/packq.h b/src/xnnpack/packq.h index b7576cbde82c..30f5f42928dc 100644 --- a/src/xnnpack/packq.h +++ b/src/xnnpack/packq.h @@ -12,7 +12,6 @@ #include "xnnpack/common.h" #include "xnnpack/config-types.h" -#include "xnnpack/config.h" #include "xnnpack/math.h" #ifdef __cplusplus @@ -57,13 +56,11 @@ XNN_INLINE static size_t xnn_x8_packq_f32qp8_packed_size(size_t m, size_t k, return num_rows * lhs_packed_stride(k, mr_packed, kr, sr); } -XNN_INLINE static size_t xnn_x8_packq_f32qp8_gemm_packed_size(size_t m, - size_t k) { - const struct xnn_gemm_config* gemm_config = - xnn_init_qp8_f32_qc4w_gemm_config(); - assert(gemm_config != NULL); - - const uint32_t mr_packed = m == 1 ? 1 : gemm_config->mr_packed; +XNN_INLINE static size_t xnn_x8_packq_f32qp8_gemm_packed_size( + const struct xnn_gemm_config* gemm_config, size_t m, size_t k) { + const uint32_t mr_packed = m == 1 ? 1 + : gemm_config->mr_packed ? gemm_config->mr_packed + : gemm_config->mr; const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; diff --git a/src/xnnpack/quantization.h b/src/xnnpack/quantization.h index edfef28faf0d..484a25341798 100644 --- a/src/xnnpack/quantization.h +++ b/src/xnnpack/quantization.h @@ -39,6 +39,41 @@ static inline struct xnn_qd8_quantization_params xnn_qd8_asymmetric_quantization return quantization_params; } +static inline struct xnn_qd8_quantization_params +xnn_qdu8_asymmetric_quantization_params(float min, float max) { + struct xnn_qd8_quantization_params quantization_params; + const float rmin = math_min_f32(0.0f, min); + const float rmax = math_max_f32(0.0f, max); + const float qmin = 0; + const float qmax = UINT8_MAX; + const float scale = rmin == rmax ? 
1.f : (qmax - qmin) / (rmax - rmin); + int32_t zero_point = lrintf(-rmin * scale); + quantization_params.inv_scale = scale; + quantization_params.zero_point = zero_point; + return quantization_params; +} + +static inline struct xnn_qd8_quantization_params +xnn_f16_qdu8_asymmetric_quantization_params(xnn_float16 min, xnn_float16 max, + xnn_float16* f16_scale) { + struct xnn_qd8_quantization_params params = + xnn_qdu8_asymmetric_quantization_params(xnn_float16_to_float(min), + xnn_float16_to_float(max)); + *f16_scale = xnn_float16_from_float(params.inv_scale); + params.inv_scale = 1.f / params.inv_scale; + return params; +} + +static inline struct xnn_qd8_quantization_params +xnn_f32_qdu8_asymmetric_quantization_params(float min, float max, + float* f32_scale) { + struct xnn_qd8_quantization_params params = + xnn_qdu8_asymmetric_quantization_params(min, max); + *f32_scale = params.inv_scale; + params.inv_scale = 1.f / params.inv_scale; + return params; +} + static inline struct xnn_qd8_quantization_params xnn_f32_qd8_asymmetric_quantization_params( float min, float max, float* f32_scale) { diff --git a/src/xnnpack/raddextexp.h b/src/xnnpack/raddextexp.h index 6ee470a50f2b..9e08b6d4438f 100644 --- a/src/xnnpack/raddextexp.h +++ b/src/xnnpack/raddextexp.h @@ -14,38 +14,13 @@ extern "C" { #endif -#define DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(fn_name) \ +#define XNN_UKERNEL(arch_flags, fn_name, element_tile, datatype) \ XNN_INTERNAL void fn_name( \ size_t n, \ const float* input, \ float* sum); - -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u64) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u72) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u80) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u96) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6) - -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2) 
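For reference, the new xnn_qdu8_asymmetric_quantization_params helper added above widens the observed float range so that it contains zero and maps it onto the unsigned 8-bit range [0, UINT8_MAX]. Below is a minimal standalone sketch of that arithmetic only; it is not the library helper itself, and the example range of [-1, 3] is assumed for illustration.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Example observed range; any min/max taken from the input tensor would do.
  const float min = -1.0f, max = 3.0f;
  // Widen the range to include zero, as the helper above does.
  const float rmin = min < 0.0f ? min : 0.0f;
  const float rmax = max > 0.0f ? max : 0.0f;
  // Map [rmin, rmax] onto [0, 255] and derive the zero point from rmin.
  const float scale = rmin == rmax ? 1.0f : 255.0f / (rmax - rmin);
  const std::int32_t zero_point =
      static_cast<std::int32_t>(std::lrintf(-rmin * scale));
  // For [-1, 3]: scale = 255 / 4 = 63.75 and zero_point = lrintf(63.75) = 64.
  std::printf("scale=%f zero_point=%d\n", static_cast<double>(scale), zero_point);
  return 0;
}

The f16 and f32 wrappers shown above hand this scale back to the caller and then store its reciprocal in the params' inv_scale field before returning.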
-DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3) -DECLARE_F32_RADDEXTEXP_UKERNEL_FUNCTION(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6) - +#include "f32-raddextexp/f32-raddextexp.h" +#undef XNN_UKERNEL #ifdef __cplusplus } /* extern "C" */ diff --git a/src/xnnpack/reference-utils.h b/src/xnnpack/reference-utils.h index 53bbb591e3b5..0f962981e88c 100644 --- a/src/xnnpack/reference-utils.h +++ b/src/xnnpack/reference-utils.h @@ -42,12 +42,12 @@ Result round_float_to_int(float x) { } template -float dequantize(T x, float scale, int32_t zero_point) { - return (static_cast(x) - static_cast(zero_point)) * scale; +float dequantize(T x, float scale, float zero_point) { + return (static_cast(x) - zero_point) * scale; } template -T quantize(float x, float inv_scale, int32_t zero_point) { +T quantize(float x, float inv_scale, float zero_point) { return round_float_to_int(x * inv_scale + zero_point); } diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h index cd4e10d2ff86..dc0c16cdc7c1 100644 --- a/src/xnnpack/subgraph.h +++ b/src/xnnpack/subgraph.h @@ -12,6 +12,7 @@ #include "xnnpack/allocation-type.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" +#include "xnnpack/config-types.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" #include "pthreadpool.h" @@ -30,7 +31,6 @@ #define XNN_INVALID_NODE_ID UINT32_MAX #define XNN_MAX_OPERATOR_OBJECTS 5 -#define XNN_MAX_SUBGRAPH_INPUT_OR_OUTPUTS 16 /// Disable fusion of nodes in subgraph. Fusion is enabled by default, set this flag to turn it off. #define XNN_FLAG_NO_OPERATOR_FUSION 0x80000000 @@ -123,11 +123,13 @@ struct xnn_value { uint32_t flags; /// Static initialization data. Must be null for non-static values. void* data; - /// Index of the Subgraph node that produced the value, or XNN_INVALID_NODE_ID is the Value is an external input. + /// Index of the Subgraph node that produced the value, or XNN_INVALID_NODE_ID + /// is the Value is an external input. uint32_t producer; /// Index of the first Node that consume the value, or XNN_INVALID_NODE_ID if the Value has no consumers within the /// graph (e.g. Value is an external output). uint32_t first_consumer; + bool all_consumers_types_same; /// Number of Nodes that consume the value. /// If multiple inputs in a Node refer to this Value as input, the Node is counted as consumer multiple times. /// If the Value is an external output, it counts as having an extra consumer. @@ -146,6 +148,10 @@ struct xnn_value { /// Used during analysis in xnn_subgraph_rewrite_for_fp16. /// Temporary buffer to convert static data to FP16. void* fp16_temp_data; + // Pointer to a `xnn_gemm_config` if this value is packed for a specific GEMM. + const struct xnn_gemm_config *gemm_config; + // If true, assume dimensions > 2 will be squashed to 2 dimensions. + bool squash_groups; // Pointer to original fp32 data if this value was converted from fp32 to fp16 (only for static values). This is used // for nodes like Convolution, where the filter is expected to be kept as fp32, but could have been converted to fp16 // if another node (like Subtraction) also consumed the weights. @@ -335,11 +341,15 @@ struct xnn_node { uint32_t layout_flags; uint32_t cluster_leader; // Number of filter parameters in all 1x1 Convolutions of the sparse cluster. - // This value is properly initialized only in sparse inference analysis of 1x1 Convolutions. + // This value is properly initialized only in sparse inference analysis of 1x1 + // Convolutions. 
size_t num_params; - // Number of zero filter parameters in all 1x1 Convolutions of the sparse cluster. - // This value is properly initialized only in sparse inference analysis of 1x1 Convolutions. + // Number of zero filter parameters in all 1x1 Convolutions of the sparse + // cluster. This value is properly initialized only in sparse inference + // analysis of 1x1 Convolutions. size_t num_zeroes; + // Pointer to the runtime operator corresponding to this node. + struct xnn_operator *op; // Factory function to create an operator object from the node. xnn_create_operator_fn create; // Function to reshape an operator using opdata. @@ -455,10 +465,6 @@ struct xnn_runtime { #ifdef XNN_SLINKY_AVAILABLE // Fields used by Slinky -- unused unless XNN_FLAG_SLINKY_ENABLED is set slinky_pipeline_t slinky_pipeline; - size_t slinky_num_inputs; - size_t slinky_num_outputs; - struct xnn_value* slinky_input_values[XNN_MAX_SUBGRAPH_INPUT_OR_OUTPUTS]; - struct xnn_value* slinky_output_values[XNN_MAX_SUBGRAPH_INPUT_OR_OUTPUTS]; #endif // XNN_SLINKY_AVAILABLE }; @@ -479,9 +485,6 @@ size_t xnn_tensor_get_size(const struct xnn_value* value); size_t xnn_tensor_get_size_by_id(xnn_subgraph_t subgraph, uint32_t value_id); -// Checks if a tensor shape is completely known. -bool xnn_tensor_shape_is_static(const struct xnn_value* value); - XNN_INLINE static size_t xnn_get_rounded_size(size_t size) { // We round it to XNN_EXTRA_BYTES to ensure that we can read more than the actual size of the tensor, and round it @@ -521,7 +524,8 @@ size_t xnn_shape_multiply_trailing_dims( size_t xnn_tensor_get_dynamic_quant_param_size(const struct xnn_value* value); XNN_INLINE static size_t xnn_tensor_get_rounded_dynamic_quant_param_size(const struct xnn_value *value) { - assert (value->datatype == xnn_datatype_qdint8); + assert(value->datatype == xnn_datatype_qdint8 || + value->datatype == xnn_datatype_qduint8); // We may read out of bounds for qparams. return xnn_get_rounded_size(value->quantization.dynamic_params_size diff --git a/src/xnnpack/vcvt.h b/src/xnnpack/vcvt.h index 2e684e30f1d7..94497cfdde6d 100644 --- a/src/xnnpack/vcvt.h +++ b/src/xnnpack/vcvt.h @@ -28,7 +28,6 @@ extern "C" { #include "qs8-vcvt/qs8-vcvt.h" #include "qu8-f32-vcvt/qu8-f32-vcvt.h" #include "qu8-vcvt/qu8-vcvt.h" -#include "s32-f32-vcvt/s32-f32-vcvt.h" #undef XNN_CVT_UKERNEL_WITH_PARAMS #ifdef __cplusplus diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 31a3ab05ee30..c95c2c9ebb74 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -13,7 +13,7 @@ load( ) load( "//:build_params.bzl", - "xnnpack_select_if", + # "xnnpack_select_if", "xnnpack_simd_copts_for_arch", "xnnpack_simd_f16_archs", "xnnpack_simd_f32_archs", @@ -74,6 +74,14 @@ xnnpack_cxx_library( deps = xnnpack_test_deps_for_library(), ) +xnnpack_cxx_library( + name = "runtime_flags", + testonly = True, + srcs = ["runtime-flags.cc"], + hdrs = ["runtime-flags.h"], + deps = xnnpack_test_deps_for_library(), +) + xnnpack_cxx_library( name = "operator_test_utils", testonly = True, @@ -148,21 +156,22 @@ xnnpack_cxx_library( ) ####################### Unit tests for microkernel lists ####################### -sh_test( - name = "microkernel_lists_test", - size = "small", - srcs = ["microkernel_lists_test.sh"], - data = [ - "//:cmake_microkernel_lists", - "//:generated_microkernel_lists", - "//gen:bzl_microkernel_lists", - ], - target_compatible_with = xnnpack_select_if( - "//build_config:linux", - [], - ["@platforms//:incompatible"], - ), -) +# TODO: b/381390736 - Reenable once fixed. 
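The runtime_flags test library added here backs the xnn_test_runtime_flags() helper (declared in test/runtime-flags.h) that the test changes further down pass to xnn_create_runtime_v3 in place of the hard-coded /*flags=*/0. A minimal sketch of that usage pattern, assuming only the call shape visible in the later hunks (the helper's actual flag value is not shown in this diff):

#include "runtime-flags.h"  // assumed to declare xnn_test_runtime_flags()
#include "xnnpack.h"

// Create a runtime for a test subgraph using the shared test flags
// instead of hard-coding /*flags=*/0.
static xnn_runtime_t create_test_runtime(xnn_subgraph_t subgraph) {
  xnn_runtime_t runtime = nullptr;
  const xnn_status status = xnn_create_runtime_v3(
      subgraph, /*weights_cache=*/nullptr, /*threadpool=*/nullptr,
      xnn_test_runtime_flags(), &runtime);
  return status == xnn_status_success ? runtime : nullptr;
}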
+#sh_test( +# name = "microkernel_lists_test", +# size = "small", +# srcs = ["microkernel_lists_test.sh"], +# data = [ +# "//:cmake_microkernel_lists", +# "//:generated_microkernel_lists", +# "//gen:bzl_microkernel_lists", +# ], +# target_compatible_with = xnnpack_select_if( +# "//build_config:linux", +# [], +# ["@platforms//:incompatible"], +# ), +#) ######################### Unit tests for simd wrappers ######################### [xnnpack_unit_test( @@ -294,7 +303,6 @@ sh_test( "f32_f16_vcvt", "f32_qs8_vcvt", "f32_qu8_vcvt", - "s32_f32_vcvt", "qs8_f16_vcvt", "qs8_f32_vcvt", "qs8_vcvt", @@ -740,7 +748,6 @@ xnnpack_unit_test( name = "f32_raddextexp_test", srcs = [ "f32-raddextexp.cc", - "raddextexp-microkernel-tester.h", ], deps = MICROKERNEL_TEST_DEPS, ) @@ -916,6 +923,18 @@ xnnpack_unit_test( ], ) +xnnpack_unit_test( + name = "qp8_f32_qc8w_gemm_minmax_test", + timeout = "moderate", + srcs = [ + "qp8-f32-qc8w-gemm-minmax.cc", + ], + defines = xnnpack_kleidiai_defines(), + deps = MICROKERNEL_TEST_DEPS + [ + ":gemm_microkernel_tester", + ], +) + xnnpack_unit_test( name = "qs8_qc8w_gemm_minmax_fp32_test", timeout = "moderate", @@ -1065,15 +1084,6 @@ xnnpack_unit_test( deps = MICROKERNEL_TEST_DEPS, ) -xnnpack_unit_test( - name = "x8_zip_test", - srcs = [ - "x8-zip.cc", - "zip-microkernel-tester.h", - ], - deps = MICROKERNEL_TEST_DEPS, -) - xnnpack_unit_test( name = "x32_packb_test", srcs = [ @@ -1132,15 +1142,6 @@ xnnpack_unit_test( deps = MICROKERNEL_TEST_DEPS, ) -xnnpack_unit_test( - name = "x32_zip_test", - srcs = [ - "x32-zip.cc", - "zip-microkernel-tester.h", - ], - deps = MICROKERNEL_TEST_DEPS, -) - xnnpack_unit_test( name = "xx_fill_test", srcs = ["xx-fill.cc"], @@ -1215,16 +1216,7 @@ xnnpack_unit_test( "batch-matrix-multiply-operator-tester.h", ], shard_count = 2, - deps = OPERATOR_TEST_DEPS, -) - -xnnpack_unit_test( - name = "channel_shuffle_nc_test", - srcs = [ - "channel-shuffle-nc.cc", - "channel-shuffle-operator-tester.h", - ], - deps = OPERATOR_TEST_DEPS, + deps = OPERATOR_TEST_DEPS + ["//:microkernels_h"], ) xnnpack_unit_test( @@ -1547,13 +1539,13 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", ":unary_ops", "//:XNNPACK", "//:buffer", "//:datatype", "//:logging", "//:operator_utils", - "//:operators", "//:subgraph", ], ) @@ -1565,6 +1557,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", ":subgraph_unary_tester", "//:XNNPACK", "//:buffer", @@ -1609,9 +1602,11 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:allocation_type", "//:buffer", + "//:common", "//:math", "//:subgraph", ], @@ -1623,6 +1618,7 @@ xnnpack_unit_test( deps = [ ":operator_test_utils", ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:datatype", @@ -1641,6 +1637,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", @@ -1657,6 +1654,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", @@ -1674,6 +1672,7 @@ xnnpack_unit_test( "average-pooling-2d-reshape.cc", ], deps = [ + ":runtime_flags", "//:XNNPACK", "//:node_type", "//:subgraph", @@ -1687,6 +1686,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", @@ -1704,6 +1704,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:math", @@ -1726,6 
+1727,7 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", @@ -1747,8 +1749,10 @@ xnnpack_unit_test( shard_count = 5, deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", + "//:internal", "//:math", "//:node_type", "//:operator_utils", @@ -1765,6 +1769,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:math", @@ -1782,6 +1787,7 @@ xnnpack_unit_test( deps = [ ":convolution_test_helpers", ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", @@ -1801,6 +1807,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:math", @@ -1822,13 +1829,11 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", - "//:config_hdrs", - "//:internal", "//:math", - "//:microkernels_h", "//:node_type", "//:operators", "//:requantization", @@ -1913,6 +1918,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:math", @@ -1931,6 +1937,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:math", @@ -1947,6 +1954,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:aligned_allocator", "//:common", @@ -1963,6 +1971,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:common", @@ -1981,6 +1990,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:subgraph", @@ -1994,6 +2004,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:math", @@ -2009,6 +2020,7 @@ xnnpack_unit_test( "transpose-reshape.cc", ], deps = [ + ":runtime_flags", "//:XNNPACK", "//:node_type", "//:subgraph", @@ -2022,6 +2034,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:node_type", @@ -2040,6 +2053,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:node_type", @@ -2058,6 +2072,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:subgraph", @@ -2073,6 +2088,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:subgraph", @@ -2088,6 +2104,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:buffer", "//:node_type", @@ -2119,6 +2136,7 @@ xnnpack_unit_test( ], deps = [ ":replicable_random_device", + ":runtime_flags", "//:XNNPACK", "//:allocation_type", "//:allocator", diff --git a/test/argmax-pooling-2d.cc b/test/argmax-pooling-2d.cc index 8fdf89f2e057..e236051be9f5 100644 --- a/test/argmax-pooling-2d.cc +++ b/test/argmax-pooling-2d.cc @@ -20,6 +20,7 @@ #include "xnnpack/subgraph.h" #include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "runtime-flags.h" namespace { inline size_t compute_output_dimension(size_t padded_input_dimension, size_t kernel_dimension) @@ -209,7 +210,7 @@ TEST_F(ArgmaxPoolingTestF32, matches_operator_api) subgraph, input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, 
pooling_height, pooling_width, input_id, output_value_id, output_index_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -269,7 +270,7 @@ TEST_F(ArgmaxPoolingTestF32, reshape_output) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/average-pooling-2d-reshape.cc b/test/average-pooling-2d-reshape.cc index 545f817e09ff..616ad2f494ea 100644 --- a/test/average-pooling-2d-reshape.cc +++ b/test/average-pooling-2d-reshape.cc @@ -13,6 +13,7 @@ #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/subgraph.h" +#include "runtime-flags.h" TEST(AveragePooling2DTestF32, Reshape) { @@ -58,7 +59,7 @@ TEST(AveragePooling2DTestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -121,7 +122,7 @@ TEST(AveragePooling2DTestF32, ReshapeWithPadding) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/average-pooling-2d.cc b/test/average-pooling-2d.cc index 4496b093586f..c8300b0cb436 100644 --- a/test/average-pooling-2d.cc +++ b/test/average-pooling-2d.cc @@ -22,6 +22,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template < typename InputType, @@ -255,7 +256,7 @@ TEST_F(AveragePoolingTestF16, matches_operator_api) subgraph, input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, output_min, output_max, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -329,7 +330,7 @@ TEST_F(AveragePoolingTestF32, matches_operator_api) subgraph, input_padding_top, input_padding_right, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, output_min, output_max, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, 
nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/batch-matrix-multiply-nc.cc b/test/batch-matrix-multiply-nc.cc index 9402669bff7c..4f3e5805167b 100644 --- a/test/batch-matrix-multiply-nc.cc +++ b/test/batch-matrix-multiply-nc.cc @@ -199,6 +199,20 @@ TEST_P(BatchMatMulTest, TestQD8F32QC8W) { .TestQD8F32QC8W(); } +TEST_P(BatchMatMulTest, TestQP8F32QC8W) { + const BatchMatMulTesterParams& params = GetParam(); + BatchMatMulOperatorTester() + .batch_dims_a(params.batch_dims_a) + .batch_dims_b(params.batch_dims_b) + .m(params.m) + .k(params.k) + .n(params.n) + .transpose_b(params.transpose_b) + .iterations(params.iterations) + .expected_status_reshape(params.expected_status_reshape) + .TestQP8F32QC8W(); +} + // Create tests for different batch sizes with different amounts of // broadcasting, with and without transposition. INSTANTIATE_TEST_SUITE_P( diff --git a/test/batch-matrix-multiply-operator-tester.h b/test/batch-matrix-multiply-operator-tester.h index 4451314e18c2..e6d6cf4bf867 100644 --- a/test/batch-matrix-multiply-operator-tester.h +++ b/test/batch-matrix-multiply-operator-tester.h @@ -19,9 +19,13 @@ #include #include "xnnpack.h" +#include "xnnpack/buffer.h" #include "xnnpack/common.h" +#include "xnnpack/config-types.h" +#include "xnnpack/config.h" +#include "xnnpack/internal.h" #include "xnnpack/math.h" -#include "xnnpack/buffer.h" +#include "xnnpack/packq.h" #include "replicable_random_device.h" class BatchMatMulOperatorTester { @@ -32,9 +36,7 @@ class BatchMatMulOperatorTester { return *this; } - size_t m() const { - return this->m_; - } + size_t m() const { return this->m_; } BatchMatMulOperatorTester& k(size_t k) { assert(k >= 1); @@ -42,9 +44,7 @@ class BatchMatMulOperatorTester { return *this; } - size_t k() const { - return this->k_; - } + size_t k() const { return this->k_; } BatchMatMulOperatorTester& n(size_t n) { assert(n >= 1); @@ -52,9 +52,7 @@ class BatchMatMulOperatorTester { return *this; } - size_t n() const { - return this->n_; - } + size_t n() const { return this->n_; } inline BatchMatMulOperatorTester& batch_dims_a( std::vector batch_dims_a) { @@ -87,18 +85,14 @@ class BatchMatMulOperatorTester { return *this; } - bool transpose_b() const { - return this->transpose_b_; - } + bool transpose_b() const { return this->transpose_b_; } BatchMatMulOperatorTester& iterations(size_t iterations) { this->iterations_ = iterations; return *this; } - size_t iterations() const { - return this->iterations_; - } + size_t iterations() const { return this->iterations_; } uint32_t flags() const { if (transpose_b()) { @@ -141,8 +135,8 @@ class BatchMatMulOperatorTester { } static void ComputeRefF16(size_t m, size_t k, size_t n, bool transpose_b, - const xnn_float16* input_a, const xnn_float16* input_b, - float* output_ref) { + const xnn_float16* input_a, + const xnn_float16* input_b, float* output_ref) { std::fill(output_ref, output_ref + m * n, 0.0f); if (transpose_b) { @@ -151,8 +145,7 @@ class BatchMatMulOperatorTester { for (size_t ni = 0; ni < n; ni++) { for (size_t ki = 0; ki < k; ki++) { output_ref[mi * n + ni] += - input_a[mi * k + ki] * - input_b[ni * k + ki]; + input_a[mi * k + ki] * input_b[ni * k + ki]; } } } @@ -162,8 +155,7 @@ class BatchMatMulOperatorTester { for (size_t ni = 0; ni < n; ni++) { for (size_t ki = 0; ki < k; ki++) { output_ref[mi * n + ni] += - input_a[mi * k + ki] * - input_b[ki * n + ni]; + input_a[mi * k + ki] * 
input_b[ki * n + ni]; } } } @@ -273,9 +265,9 @@ class BatchMatMulOperatorTester { } xnnpack::Buffer input_a(XNN_EXTRA_BYTES / sizeof(xnn_float16) + - batch_size_a * m() * k()); + batch_size_a * m() * k()); xnnpack::Buffer input_b(XNN_EXTRA_BYTES / sizeof(xnn_float16) + - batch_size_b * k() * n()); + batch_size_b * k() * n()); xnnpack::Buffer output(batch_size_output * m() * n()); xnnpack::Buffer output_ref(batch_size_output * m() * n()); @@ -293,7 +285,8 @@ class BatchMatMulOperatorTester { ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); xnn_operator_t batch_matrix_multiply_op = nullptr; - const xnn_status status = xnn_create_batch_matrix_multiply_nc_f16(flags(), &batch_matrix_multiply_op); + const xnn_status status = xnn_create_batch_matrix_multiply_nc_f16( + flags(), &batch_matrix_multiply_op); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } @@ -301,8 +294,9 @@ class BatchMatMulOperatorTester { ASSERT_NE(nullptr, batch_matrix_multiply_op); // Smart pointer to automatically delete batch_matrix_multiply_op. - std::unique_ptr auto_batch_matrix_multiply_op( - batch_matrix_multiply_op, xnn_delete_operator); + std::unique_ptr + auto_batch_matrix_multiply_op(batch_matrix_multiply_op, + xnn_delete_operator); size_t workspace_size = 0; size_t workspace_alignment = 0; @@ -326,8 +320,8 @@ class BatchMatMulOperatorTester { batch_matrix_multiply_op, workspace.data(), input_a.data(), input_b.data(), output.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(batch_matrix_multiply_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_run_operator(batch_matrix_multiply_op, + /*threadpool=*/nullptr)); VerifyF16(output, output_ref); } @@ -356,9 +350,9 @@ class BatchMatMulOperatorTester { } xnnpack::Buffer input_a(XNN_EXTRA_BYTES / sizeof(float) + - batch_size_a * m() * k()); + batch_size_a * m() * k()); xnnpack::Buffer input_b(XNN_EXTRA_BYTES / sizeof(float) + - batch_size_b * k() * n()); + batch_size_b * k() * n()); xnnpack::Buffer output(batch_size_output * m() * n()); xnnpack::Buffer output_ref(batch_size_output * m() * n()); @@ -412,7 +406,8 @@ class BatchMatMulOperatorTester { ASSERT_NE(workspace_size, 0); ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT); } - xnnpack::Buffer workspace(workspace_size); + xnnpack::Buffer workspace( + workspace_size); // TODO(b/372731180): This should probably be initialized by the // operator. 
std::fill(workspace.begin(), workspace.end(), 0); @@ -430,6 +425,53 @@ class BatchMatMulOperatorTester { } } + void ComputeQC8W( + xnnpack::Buffer& input_b, size_t batch_size_b, + xnnpack::Buffer& input_b_qc8, + xnnpack::Buffer& channelwise_scale_b) const { + if (transpose_b_) { + for (size_t b = 0; b < batch_size_b; b++) { + for (size_t c = 0; c < n(); c++) { + const size_t offset = b * n() * k() + c * k(); + float max_abs = 0.0f; + for (size_t i = 0; i < k(); i++) { + max_abs = std::max(max_abs, std::abs(input_b[offset + i])); + } + if (max_abs == 0.0f) { + max_abs = 1.0f; + } + const float scale = max_abs / std::numeric_limits::max(); + const float inv_scale = 1.0f / scale; + for (size_t i = 0; i < k(); i++) { + input_b_qc8[offset + i] = static_cast( + std::round(input_b[offset + i] * inv_scale)); + } + channelwise_scale_b[b * n() + c] = scale; + } + } + } else { + for (size_t b = 0; b < batch_size_b; b++) { + const size_t bnk = b * n() * k(); + for (size_t c = 0; c < n(); c++) { + float max_abs = 0.0f; + for (size_t i = 0; i < k(); i++) { + max_abs = std::max(max_abs, std::abs(input_b[bnk + i * n() + c])); + } + if (max_abs == 0.0f) { + max_abs = 1.0f; + } + const float scale = max_abs / std::numeric_limits::max(); + const float inv_scale = 1.0f / scale; + for (size_t i = 0; i < k(); i++) { + input_b_qc8[bnk + i * n() + c] = static_cast( + std::round(input_b[bnk + i * n() + c] * inv_scale)); + } + channelwise_scale_b[b * n() + c] = scale; + } + } + } + } + void TestQD8F32QC8W() const { ASSERT_EQ(batch_dims_a().size(), batch_dims_b().size()); const size_t num_batch_dims = batch_dims_a().size(); @@ -454,9 +496,9 @@ class BatchMatMulOperatorTester { } xnnpack::Buffer input_a(XNN_EXTRA_BYTES / sizeof(float) + - batch_size_a * m() * k()); + batch_size_a * m() * k()); xnnpack::Buffer input_b(XNN_EXTRA_BYTES / sizeof(float) + - batch_size_b * k() * n()); + batch_size_b * k() * n()); xnnpack::Buffer output(batch_size_output * m() * n()); xnnpack::Buffer output_ref(batch_size_output * m() * n()); @@ -473,7 +515,7 @@ class BatchMatMulOperatorTester { xnnpack::Buffer quantization_params( batch_size_a * m() + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer input_a_qd8(batch_size_a * m() * k() + - XNN_EXTRA_BYTES / sizeof(int8_t)); + XNN_EXTRA_BYTES / sizeof(int8_t)); xnn_operator_t convert_op = nullptr; xnn_status status = xnn_create_convert_nc_f32_qd8( /*flags=*/0, &convert_op); @@ -496,51 +538,11 @@ class BatchMatMulOperatorTester { // Compute the channelwise quantized input_b. 
xnnpack::Buffer input_b_qc8(XNN_EXTRA_BYTES / sizeof(int8_t) + - batch_size_b * k() * n()); - xnnpack::Buffer channelwise_scale_b(XNN_EXTRA_BYTES / sizeof(float) + - batch_size_b * n()); - if (transpose_b_) { - for (size_t b = 0; b < batch_size_b; b++) { - for (size_t c = 0; c < n(); c++) { - const size_t offset = b * n() * k() + c * k(); - float max_abs = 0.0f; - for (size_t i = 0; i < k(); i++) { - max_abs = std::max(max_abs, std::abs(input_b[offset + i])); - } - if (max_abs == 0.0f) { - max_abs = 1.0f; - } - const float scale = max_abs / std::numeric_limits::max(); - const float inv_scale = 1.0f / scale; - for (size_t i = 0; i < k(); i++) { - input_b_qc8[offset + i] = static_cast( - std::round(input_b[offset + i] * inv_scale)); - } - channelwise_scale_b[b * n() + c] = scale; - } - } - } else { - for (size_t b = 0; b < batch_size_b; b++) { - const size_t bnk = b * n() * k(); - for (size_t c = 0; c < n(); c++) { - float max_abs = 0.0f; - for (size_t i = 0; i < k(); i++) { - max_abs = std::max(max_abs, std::abs(input_b[bnk + i * n() + c])); - } - if (max_abs == 0.0f) { - max_abs = 1.0f; - } - const float scale = max_abs / std::numeric_limits::max(); - const float inv_scale = 1.0f / scale; - for (size_t i = 0; i < k(); i++) { - input_b_qc8[bnk + i * n() + c] = static_cast( - std::round(input_b[bnk + i * n() + c] * inv_scale)); - } - channelwise_scale_b[b * n() + c] = scale; - } - } - } - + batch_size_b * k() * n()); + xnnpack::Buffer channelwise_scale_b( + XNN_EXTRA_BYTES / sizeof(float) + batch_size_b * n()); + ComputeQC8W(input_b, batch_size_b, input_b_qc8, + channelwise_scale_b); // Compute reference results. ComputeReference(batch_dims_output, input_a.data(), input_b.data(), output_ref.data(), ComputeRefF32); @@ -583,6 +585,124 @@ class BatchMatMulOperatorTester { } } + void TestQP8F32QC8W() const { + const struct xnn_gemm_config* gemm_config = + xnn_init_qp8_f32_qc8w_gemm_config(); + if (gemm_config == nullptr) { + GTEST_SKIP(); + } + + ASSERT_EQ(batch_dims_a().size(), batch_dims_b().size()); + const size_t num_batch_dims = batch_dims_a().size(); + + xnnpack::ReplicableRandomDevice rng; + std::uniform_real_distribution f32dist(range_f32_.first, + range_f32_.second); + + size_t batch_size_a = 1; + for (int k = 0; k < num_batch_dims; k++) { + batch_size_a *= batch_dims_a()[k]; + } + size_t batch_size_b = 1; + for (int k = 0; k < num_batch_dims; k++) { + batch_size_b *= batch_dims_b()[k]; + } + std::vector batch_dims_output(num_batch_dims); + size_t batch_size_output = 1; + for (int k = 0; k < num_batch_dims; k++) { + batch_dims_output[k] = std::max(batch_dims_a()[k], batch_dims_b()[k]); + batch_size_output *= batch_dims_output[k]; + } + + xnnpack::Buffer input_a(XNN_EXTRA_BYTES / sizeof(float) + + batch_size_a * m() * k()); + xnnpack::Buffer input_b(XNN_EXTRA_BYTES / sizeof(float) + + batch_size_b * k() * n()); + xnnpack::Buffer output(batch_size_output * m() * n()); + xnnpack::Buffer output_ref(batch_size_output * m() * n()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input_a.begin(), input_a.end(), + [&]() { return f32dist(rng); }); + std::generate(input_b.begin(), input_b.end(), + [&]() { return f32dist(rng); }); + + ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); + + // Create the dynamically quantized input data. 
+ xnnpack::Buffer input_a_qp8( + batch_size_a * + xnn_x8_packq_f32qp8_gemm_packed_size(gemm_config, m(), k()) + + XNN_EXTRA_BYTES / sizeof(int8_t)); + xnn_operator_t convert_op = nullptr; + xnn_status status = xnn_create_convert_nc_f32_qp8( + /*flags=*/0, gemm_config, &convert_op); + std::unique_ptr + auto_convert_op(convert_op, xnn_delete_operator); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, convert_op); + ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8( + convert_op, batch_size_a, m(), k(), k(), + /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f32_qp8(convert_op, input_a.data(), + input_a_qp8.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + + // Compute the channelwise quantized input_b. + xnnpack::Buffer input_b_qc8(XNN_EXTRA_BYTES / sizeof(int8_t) + + batch_size_b * k() * n()); + xnnpack::Buffer channelwise_scale_b( + XNN_EXTRA_BYTES / sizeof(float) + batch_size_b * n()); + ComputeQC8W(input_b, batch_size_b, input_b_qc8, + channelwise_scale_b); + + // Compute reference results. + ComputeReference(batch_dims_output, input_a.data(), input_b.data(), + output_ref.data(), ComputeRefF32); + + // Create, setup, run, and destroy Fully Connected operator. + xnn_operator_t batch_matrix_multiply_op = nullptr; + + status = xnn_create_batch_matrix_multiply_nc_qp8_f32_qc8w( + batch_size_b, k(), n(), input_b_qc8.data(), + channelwise_scale_b.data(), flags(), &batch_matrix_multiply_op); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, batch_matrix_multiply_op); + + // Smart pointer to automatically delete batch_matrix_multiply_op. + std::unique_ptr + auto_batch_matrix_multiply_op(batch_matrix_multiply_op, + xnn_delete_operator); + + ASSERT_EQ(expected_status_reshape(), + xnn_reshape_batch_matrix_multiply_nc_qp8_f32_qc8w( + batch_matrix_multiply_op, num_batch_dims, + batch_dims_a().data(), batch_dims_b().data(), m(), k(), n(), + /*threadpool=*/nullptr)); + if (expected_status_reshape() != xnn_status_success) { + return; + } + + ASSERT_EQ( + xnn_status_success, + xnn_setup_batch_matrix_multiply_nc_qp8_f32_qc8w( + batch_matrix_multiply_op, input_a_qp8.data(), output.data())); + + ASSERT_EQ(xnn_status_success, xnn_run_operator(batch_matrix_multiply_op, + /*threadpool=*/nullptr)); + + VerifyQD8F32QC8W(output, output_ref); + } + } + void VerifyF16(const xnnpack::Buffer& output, const xnnpack::Buffer& output_ref) const { const size_t batch_size_output = output.size() / (m() * n()); diff --git a/test/batch-matrix-multiply.cc b/test/batch-matrix-multiply.cc index 2b8fb8af5f0d..50fbac559338 100644 --- a/test/batch-matrix-multiply.cc +++ b/test/batch-matrix-multiply.cc @@ -6,9 +6,9 @@ #include // For std::generate. #include // For std::array. #include -#include #include // For size_t. #include // For uint32_t. 
+#include #include #include #include @@ -28,6 +28,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class BatchMatrixMultiplyTestBase : public ::testing::Test { @@ -40,7 +41,7 @@ class BatchMatrixMultiplyTestBase : public ::testing::Test { -std::numeric_limits::max(), std::numeric_limits::max()); auto shape_dist = - std::uniform_int_distribution(4, XNN_MAX_TENSOR_DIMS); + std::uniform_int_distribution(2, XNN_MAX_TENSOR_DIMS); auto broadcast_dist = std::uniform_int_distribution(0, 4); dim_dist = std::uniform_int_distribution(5, 15); @@ -51,7 +52,6 @@ class BatchMatrixMultiplyTestBase : public ::testing::Test { // where G is an integer multiple of H. size_t num_input_dims = shape_dist(rng); input1_dims = RandomShape(num_input_dims); - assert(input1_dims.size() >= 3); m = input1_dims[num_input_dims - 2]; k = input1_dims.back(); @@ -330,7 +330,7 @@ TEST_F(BatchMatrixMultiplyTestF16, matches_operator_api) xnn_define_batch_matrix_multiply(subgraph, input1_id, input2_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -419,7 +419,7 @@ TEST_F(BatchMatrixMultiplyTestF32, matches_operator_api) xnn_define_batch_matrix_multiply(subgraph, input1_id, input2_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -508,7 +508,7 @@ TEST_F(BatchMatrixMultiplyTestF32, matches_operator_api_transposed) xnn_define_batch_matrix_multiply(subgraph, input1_id, input2_id, output_id, /*flags=*/XNN_FLAG_TRANSPOSE_B)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -688,8 +688,9 @@ TEST_F(BatchMatrixMultiplyTestQD8ToF32, matches_operator_api) { ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); // Define the ops. 
- ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input1_f32_id, - input1_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input1_f32_id, input1_id, /*flags=*/0)); ASSERT_EQ(xnn_status_success, xnn_define_batch_matrix_multiply(subgraph, input1_id, input2_id, output_id, /*flags=*/0)); @@ -697,7 +698,7 @@ TEST_F(BatchMatrixMultiplyTestQD8ToF32, matches_operator_api) { xnn_runtime_t runtime = nullptr; ASSERT_EQ( xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime( runtime, xnn_delete_runtime); @@ -708,9 +709,12 @@ TEST_F(BatchMatrixMultiplyTestQD8ToF32, matches_operator_api) { xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - // Check outputs match. + float max_abs_val = 0.0f; for (size_t i = 0; i < operator_output.size(); i++) { - ASSERT_EQ(subgraph_output[i], operator_output[i]) + max_abs_val = std::max(max_abs_val, std::abs(operator_output[i])); + } + for (size_t i = 0; i < operator_output.size(); i++) { + ASSERT_NEAR(operator_output[i], subgraph_output[i], max_abs_val * 2.0e-3) << " at index " << i << " of " << operator_output.size(); } } @@ -768,7 +772,7 @@ void DefineAndReshapeBatchMatrixMultiplySubgraph( xnn_runtime_t runtime = nullptr; ASSERT_EQ( xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr clean_up_subgraph(subgraph, xnn_delete_subgraph); @@ -801,7 +805,7 @@ TEST(BatchMatrixMultiplyReshapeTest, reshape_input1) { ASSERT_EQ(xnn_status_success, status); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/bf16-gemm-minmax.cc b/test/bf16-gemm-minmax.cc index be2e0fc846e3..c54e9440ef16 100644 --- a/test/bf16-gemm-minmax.cc +++ b/test/bf16-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - 
gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, xnn_init_bf16_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params, @@ -427,6 +412,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params, @@ -465,6 +452,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params, @@ -484,6 +472,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, xnn_init_bf16_minmax_scalar_params, @@ -506,6 +495,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params, @@ -525,6 +515,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params, @@ -544,6 +535,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params, @@ -563,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, xnn_init_bf16_minmax_scalar_params, @@ -582,6 +575,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params, @@ -601,6 +595,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params, @@ -620,6 +615,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params, @@ -639,6 +635,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params, @@ -658,6 +655,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, xnn_init_bf16_minmax_scalar_params, @@ -677,6 +675,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params, @@ -696,6 +695,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params, @@ -715,6 +715,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params, @@ -734,6 +735,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params, @@ -753,6 +755,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, xnn_init_bf16_minmax_scalar_params, diff --git a/test/binary-elementwise-nd.cc b/test/binary-elementwise-nd.cc index b087b5adbab6..bb7373846b9b 100644 --- a/test/binary-elementwise-nd.cc +++ b/test/binary-elementwise-nd.cc @@ -25,6 +25,7 @@ #include #include "xnnpack.h" #include "xnnpack/buffer.h" +#include "xnnpack/common.h" #include "xnnpack/datatype.h" #include "xnnpack/log.h" #include "xnnpack/math.h" @@ -357,8 +358,8 @@ class BinaryElementwiseOperatorTester { xnnpack::Buffer input2(XNN_EXTRA_BYTES + num_input2_elements()); xnnpack::Buffer output(num_output_elements); for (size_t iteration = 0; iteration < iterations(); iteration++) { - randomize_buffer(datatype(), rng, dist, input1); - randomize_buffer(datatype(), rng, dist, input2); + xnnpack::randomize_buffer(datatype(), rng, dist, input1); + xnnpack::randomize_buffer(datatype(), rng, dist, input2); if (mode == RunMode::kCreateReshapeRun) { // Create, setup, run, and destroy a binary elementwise operator. @@ -486,8 +487,8 @@ class BinaryElementwiseOperatorTester { } // ValidateResults for integral (but non-quantized) types. - template ::value>::type* = nullptr> + template ::value>::type* = nullptr> void ValidateResults( const xnnpack::Buffer& input1, const std::array& input1_strides, @@ -497,7 +498,7 @@ class BinaryElementwiseOperatorTester { const std::array& output_strides, const std::array& output_dims) { // Verify results. - static_assert(!xnnpack::is_quantized::value); + static_assert(!xnnpack::is_quantized::value, ""); MinMaxLow limits = DatatypeMinMaxLow(datatype()); for (size_t i = 0; i < output_dims[0]; i++) { for (size_t j = 0; j < output_dims[1]; j++) { @@ -546,8 +547,8 @@ class BinaryElementwiseOperatorTester { } // ValidateResults for all other types (float variants). - template ::value>::type* = nullptr> + template ::value>::type* = nullptr> void ValidateResults( const xnnpack::Buffer& input1, const std::array& input1_strides, @@ -557,7 +558,7 @@ class BinaryElementwiseOperatorTester { const std::array& output_strides, const std::array& output_dims) { // Verify results. 
- static_assert(!xnnpack::is_quantized::value); + static_assert(!xnnpack::is_quantized::value, ""); MinMaxLow limits = DatatypeMinMaxLow(datatype()); for (size_t i = 0; i < output_dims[0]; i++) { for (size_t j = 0; j < output_dims[1]; j++) { diff --git a/test/binary.cc b/test/binary.cc index 965324c4c9a5..3399f5a22812 100644 --- a/test/binary.cc +++ b/test/binary.cc @@ -31,6 +31,7 @@ #include "xnnpack/subgraph.h" #include "operator-test-utils.h" #include "replicable_random_device.h" +#include "runtime-flags.h" using ::testing::Combine; using ::testing::ValuesIn; @@ -258,7 +259,7 @@ void MatchesOperatorApi(xnn_datatype datatype, xnn_binary_operator binary_op) { xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_parameter) { GTEST_SKIP(); } @@ -365,7 +366,7 @@ void Reshape(xnn_datatype datatype, xnn_binary_operator binary_op) { xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_parameter) { GTEST_SKIP(); } @@ -449,7 +450,7 @@ void ReshapeBroadcastDim0(xnn_datatype datatype, xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_parameter) { GTEST_SKIP(); } @@ -532,7 +533,7 @@ void ReshapeBroadcast1D(xnn_datatype datatype, xnn_binary_operator binary_op) { xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_parameter) { GTEST_SKIP(); } @@ -615,7 +616,7 @@ void ReshapeBroadcast2D(xnn_datatype datatype, xnn_binary_operator binary_op) { xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_parameter) { GTEST_SKIP(); } @@ -698,7 +699,7 @@ void DegenerateDimension(xnn_datatype datatype, xnn_binary_operator binary_op) { xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_parameter) { GTEST_SKIP(); } diff --git a/test/channel-shuffle-nc.cc b/test/channel-shuffle-nc.cc deleted file mode 100644 index b84e9f993d48..000000000000 --- a/test/channel-shuffle-nc.cc +++ /dev/null @@ -1,504 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include - -#include -#include "channel-shuffle-operator-tester.h" - -TEST(CHANNEL_SHUFFLE_NC_X8, two_groups_unit_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(2) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, three_groups_unit_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(3) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, four_groups_unit_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(4) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, many_groups_unit_batch) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(groups) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, two_groups_small_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, three_groups_small_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, four_groups_small_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, many_groups_small_batch) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .iterations(3) - .TestX8(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, two_groups_small_batch_with_input_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .input_stride(511) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, three_groups_small_batch_with_input_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .input_stride(511) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, four_groups_small_batch_with_input_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .input_stride(511) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, many_groups_small_batch_with_input_stride) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - 
.batch_size(3) - .groups(groups) - .group_channels(group_channels) - .input_stride(1007) - .iterations(3) - .TestX8(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, two_groups_small_batch_with_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .output_stride(513) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, three_groups_small_batch_with_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .output_stride(513) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, four_groups_small_batch_with_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .output_stride(513) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, many_groups_small_batch_with_output_stride) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .output_stride(1111) - .iterations(3) - .TestX8(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, two_groups_small_batch_with_input_and_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .input_stride(511) - .output_stride(513) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, three_groups_small_batch_with_input_and_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .input_stride(511) - .output_stride(513) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, four_groups_small_batch_with_input_and_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .input_stride(511) - .output_stride(513) - .iterations(3) - .TestX8(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X8, many_groups_small_batch_with_input_and_output_stride) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .input_stride(1007) - .output_stride(1111) - .iterations(3) - .TestX8(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, two_groups_unit_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(2) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, three_groups_unit_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(3) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, four_groups_unit_batch) { - for (size_t group_channels = 1; group_channels < 100; 
group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(4) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, many_groups_unit_batch) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(1) - .groups(groups) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, two_groups_small_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, three_groups_small_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, four_groups_small_batch) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, many_groups_small_batch) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .iterations(3) - .TestX32(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, two_groups_small_batch_with_input_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .input_stride(511) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, three_groups_small_batch_with_input_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .input_stride(511) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, four_groups_small_batch_with_input_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .input_stride(511) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, many_groups_small_batch_with_input_stride) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .input_stride(1007) - .iterations(3) - .TestX32(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, two_groups_small_batch_with_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .output_stride(513) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, three_groups_small_batch_with_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - 
.output_stride(513) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, four_groups_small_batch_with_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .output_stride(513) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, many_groups_small_batch_with_output_stride) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .output_stride(1111) - .iterations(3) - .TestX32(); - } - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, two_groups_small_batch_with_input_and_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(2) - .group_channels(group_channels) - .input_stride(511) - .output_stride(513) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, three_groups_small_batch_with_input_and_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(3) - .group_channels(group_channels) - .input_stride(511) - .output_stride(513) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, four_groups_small_batch_with_input_and_output_stride) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(4) - .group_channels(group_channels) - .input_stride(511) - .output_stride(513) - .iterations(3) - .TestX32(); - } -} - -TEST(CHANNEL_SHUFFLE_NC_X32, many_groups_small_batch_with_input_and_output_stride) { - for (size_t groups = 5; groups < 12; groups += 3) { - for (size_t group_channels = 1; group_channels < 100; group_channels += 15) { - ChannelShuffleOperatorTester() - .batch_size(3) - .groups(groups) - .group_channels(group_channels) - .input_stride(1007) - .output_stride(1111) - .iterations(3) - .TestX32(); - } - } -} diff --git a/test/channel-shuffle-operator-tester.h b/test/channel-shuffle-operator-tester.h deleted file mode 100644 index e3ac7117df8e..000000000000 --- a/test/channel-shuffle-operator-tester.h +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// All rights reserved. -// -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include "xnnpack.h" -#include "xnnpack/buffer.h" -#include "replicable_random_device.h" - -class ChannelShuffleOperatorTester { - public: - ChannelShuffleOperatorTester& groups(size_t groups) { - assert(groups != 0); - this->groups_ = groups; - return *this; - } - - size_t groups() const { - return this->groups_; - } - - ChannelShuffleOperatorTester& group_channels(size_t group_channels) { - assert(group_channels != 0); - this->group_channels_ = group_channels; - return *this; - } - - size_t group_channels() const { - return this->group_channels_; - } - - size_t channels() const { - return groups() * group_channels(); - } - - ChannelShuffleOperatorTester& input_stride(size_t input_stride) { - assert(input_stride != 0); - this->input_stride_ = input_stride; - return *this; - } - - size_t input_stride() const { - if (this->input_stride_ == 0) { - return channels(); - } else { - assert(this->input_stride_ >= channels()); - return this->input_stride_; - } - } - - ChannelShuffleOperatorTester& output_stride(size_t output_stride) { - assert(output_stride != 0); - this->output_stride_ = output_stride; - return *this; - } - - size_t output_stride() const { - if (this->output_stride_ == 0) { - return channels(); - } else { - assert(this->output_stride_ >= channels()); - return this->output_stride_; - } - } - - ChannelShuffleOperatorTester& batch_size(size_t batch_size) { - assert(batch_size != 0); - this->batch_size_ = batch_size; - return *this; - } - - size_t batch_size() const { - return this->batch_size_; - } - - ChannelShuffleOperatorTester& iterations(size_t iterations) { - this->iterations_ = iterations; - return *this; - } - - size_t iterations() const { - return this->iterations_; - } - - void TestX8() const { - xnnpack::ReplicableRandomDevice rng; - std::uniform_int_distribution u8dist( - std::numeric_limits::min(), std::numeric_limits::max()); - - xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels()); - xnnpack::Buffer output((batch_size() - 1) * output_stride() + channels()); - for (size_t iteration = 0; iteration < iterations(); iteration++) { - std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); - - // Create, setup, run, and destroy Channel Shuffle operator. - ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); - xnn_operator_t channel_shuffle_op = nullptr; - - ASSERT_EQ(xnn_status_success, - xnn_create_channel_shuffle_nc_x8( - groups(), group_channels(), - input_stride(), output_stride(), - 0, &channel_shuffle_op)); - ASSERT_NE(nullptr, channel_shuffle_op); - - // Smart pointer to automatically delete channel_shuffle_op. - std::unique_ptr auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator); - - ASSERT_EQ(xnn_status_success, - xnn_reshape_channel_shuffle_nc_x8( - channel_shuffle_op, - batch_size(), - /*threadpool=*/nullptr)); - - ASSERT_EQ(xnn_status_success, - xnn_setup_channel_shuffle_nc_x8( - channel_shuffle_op, - input.data(), output.data())); - - ASSERT_EQ(xnn_status_success, - xnn_run_operator(channel_shuffle_op, /*threadpool=*/nullptr)); - - // Verify results. 
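// The loop below checks the channel-shuffle permutation directly: viewing the
// channels of each batch row as a groups() x group_channels() matrix, input
// element (g, c) must appear at output channel c * groups() + g, i.e. the
// transpose of the grouped layout. For groups() == 2 and group_channels() == 3
// this maps [a0 a1 a2 b0 b1 b2] to [a0 b0 a1 b1 a2 b2].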
- for (size_t i = 0; i < batch_size(); i++) { - for (size_t g = 0; g < groups(); g++) { - for (size_t c = 0; c < group_channels(); c++) { - ASSERT_EQ(int32_t(input[i * input_stride() + g * group_channels() + c]), - int32_t(output[i * output_stride() + c * groups() + g])) - << "batch index " << i << ", group " << g << ", channel " << c; - } - } - } - } - } - - void TestX32() const { - xnnpack::ReplicableRandomDevice rng; - std::uniform_int_distribution u32dist; - - xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(uint32_t) + (batch_size() - 1) * input_stride() + channels()); - xnnpack::Buffer output((batch_size() - 1) * output_stride() + channels()); - for (size_t iteration = 0; iteration < iterations(); iteration++) { - std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); }); - - // Create, setup, run, and destroy Channel Shuffle operator. - ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); - xnn_operator_t channel_shuffle_op = nullptr; - - ASSERT_EQ(xnn_status_success, - xnn_create_channel_shuffle_nc_x32( - groups(), group_channels(), - input_stride(), output_stride(), - 0, &channel_shuffle_op)); - ASSERT_NE(nullptr, channel_shuffle_op); - - // Smart pointer to automatically delete channel_shuffle_op. - std::unique_ptr auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator); - - ASSERT_EQ(xnn_status_success, - xnn_reshape_channel_shuffle_nc_x32( - channel_shuffle_op, - batch_size(), - /*threadpool=*/nullptr)); - - ASSERT_EQ(xnn_status_success, - xnn_setup_channel_shuffle_nc_x32( - channel_shuffle_op, - input.data(), output.data())); - - ASSERT_EQ(xnn_status_success, - xnn_run_operator(channel_shuffle_op, /*threadpool=*/nullptr)); - - // Verify results. - for (size_t i = 0; i < batch_size(); i++) { - for (size_t g = 0; g < groups(); g++) { - for (size_t c = 0; c < group_channels(); c++) { - ASSERT_EQ(input[i * input_stride() + g * group_channels() + c], - output[i * output_stride() + c * groups() + g]) - << "batch index " << i << ", group " << g << ", channel " << c; - } - } - } - } - } - - private: - size_t groups_{1}; - size_t group_channels_{1}; - size_t batch_size_{1}; - size_t input_stride_{0}; - size_t output_stride_{0}; - size_t iterations_{15}; -}; diff --git a/test/concatenate2.cc b/test/concatenate2.cc index c13e9b01b22c..b0b644e795f2 100644 --- a/test/concatenate2.cc +++ b/test/concatenate2.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class Concatenate2Test : public ::testing::Test { protected: @@ -363,7 +364,7 @@ TEST_F(Concatenate2TestQS8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -435,7 +436,7 @@ TEST_F(Concatenate2TestQU8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, 
xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -504,7 +505,7 @@ TEST_F(Concatenate2TestF16, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -573,7 +574,7 @@ TEST_F(Concatenate2TestF32, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_concatenate2(subgraph, axis, input1_id, input2_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -631,7 +632,7 @@ TEST_F(Concatenate2TestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/concatenate3.cc b/test/concatenate3.cc index 8eca94cd82de..2d74c400da7a 100644 --- a/test/concatenate3.cc +++ b/test/concatenate3.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class Concatenate3Test : public ::testing::Test { protected: @@ -435,7 +436,7 @@ TEST_F(Concatenate3TestQS8, matches_operator_api) xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -528,7 +529,7 @@ TEST_F(Concatenate3TestQU8, matches_operator_api) xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -617,7 +618,7 @@ TEST_F(Concatenate3TestF16, matches_operator_api) xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, 
xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -706,7 +707,7 @@ TEST_F(Concatenate3TestF32, matches_operator_api) xnn_define_concatenate3(subgraph, axis, input1_id, input2_id, input3_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -772,7 +773,7 @@ TEST_F(Concatenate3TestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/concatenate4.cc b/test/concatenate4.cc index 473fc0bb5fda..0f7b5dd2c1e4 100644 --- a/test/concatenate4.cc +++ b/test/concatenate4.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class Concatenate4Test : public ::testing::Test { protected: @@ -494,7 +495,7 @@ TEST_F(Concatenate4TestQS8, matches_operator_api) xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -606,7 +607,7 @@ TEST_F(Concatenate4TestQU8, matches_operator_api) xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -713,7 +714,7 @@ TEST_F(Concatenate4TestF16, matches_operator_api) xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -820,7 +821,7 @@ TEST_F(Concatenate4TestF32, matches_operator_api) xnn_define_concatenate4(subgraph, axis, input1_id, input2_id, input3_id, input4_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, 
xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -895,7 +896,7 @@ TEST_F(Concatenate4TestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/concatenate5.cc b/test/concatenate5.cc index aa3c8dc34b3d..e33072152d94 100644 --- a/test/concatenate5.cc +++ b/test/concatenate5.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class Concatenate5Test : public ::testing::Test { protected: @@ -554,7 +555,7 @@ TEST_F(Concatenate5TestQS8, matches_operator_api) xnn_define_concatenate5(subgraph, axis, input1_id, input2_id, input3_id, input4_id, input5_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -684,7 +685,7 @@ TEST_F(Concatenate5TestQU8, matches_operator_api) xnn_define_concatenate5(subgraph, axis, input1_id, input2_id, input3_id, input4_id, input5_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -808,7 +809,7 @@ TEST_F(Concatenate5TestF16, matches_operator_api) xnn_define_concatenate5(subgraph, axis, input1_id, input2_id, input3_id, input4_id, input5_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -932,7 +933,7 @@ TEST_F(Concatenate5TestF32, matches_operator_api) xnn_define_concatenate5(subgraph, axis, input1_id, input2_id, input3_id, input4_id, input5_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1015,7 +1016,7 @@ TEST_F(Concatenate5TestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, 
nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/convert-nc.cc b/test/convert-nc.cc index b796c9f0fa71..9c9004913593 100644 --- a/test/convert-nc.cc +++ b/test/convert-nc.cc @@ -12,16 +12,15 @@ #include #include #include -#include #include #include "xnnpack.h" +#include "xnnpack/buffer.h" #include "xnnpack/config-types.h" #include "xnnpack/config.h" #include "xnnpack/internal.h" #include "xnnpack/math.h" #include "xnnpack/packq.h" -#include "xnnpack/buffer.h" #include "replicable_random_device.h" class ConvertOperatorTester { @@ -32,9 +31,7 @@ class ConvertOperatorTester { return *this; } - size_t channels() const { - return this->channels_; - } + size_t channels() const { return this->channels_; } ConvertOperatorTester& input_stride(size_t input_stride) { assert(input_stride != 0); @@ -72,9 +69,7 @@ class ConvertOperatorTester { return *this; } - size_t batch_size() const { - return this->batch_size_; - } + size_t batch_size() const { return this->batch_size_; } ConvertOperatorTester& input_scale(float input_scale) { assert(input_scale >= 0.0f); @@ -83,9 +78,7 @@ class ConvertOperatorTester { return *this; } - float input_scale() const { - return this->input_scale_; - } + float input_scale() const { return this->input_scale_; } ConvertOperatorTester& output_scale(float output_scale) { assert(output_scale >= 0.0f); @@ -94,38 +87,32 @@ class ConvertOperatorTester { return *this; } - float output_scale() const { - return this->output_scale_; - } + float output_scale() const { return this->output_scale_; } ConvertOperatorTester& zero_point(int16_t zero_point) { this->zero_point_ = zero_point; return *this; } - int16_t zero_point() const { - return this->zero_point_; - } + int16_t zero_point() const { return this->zero_point_; } ConvertOperatorTester& iterations(size_t iterations) { this->iterations_ = iterations; return *this; } - size_t iterations() const { - return this->iterations_; - } + size_t iterations() const { return this->iterations_; } void TestF16toQD8() const { xnnpack::ReplicableRandomDevice rng; xnnpack::Buffer input_float((batch_size() - 1) * input_stride() + - channels()); + channels()); xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(xnn_float16) + - (batch_size() - 1) * input_stride() + - channels()); + (batch_size() - 1) * input_stride() + + channels()); xnnpack::Buffer output((batch_size() - 1) * output_stride() + - channels()); + channels()); xnnpack::Buffer quantization_params( batch_size() + XNN_EXTRA_QUANTIZATION_PARAMS); std::uniform_real_distribution range_dist(-10, 10); @@ -139,8 +126,7 @@ class ConvertOperatorTester { std::generate(input_float.begin(), input_float.end(), [&]() { return f32dist(rng); }); std::copy(input_float.begin(), input_float.end(), input.begin()); - std::copy(input.begin(), input.begin() + channels(), - input_float.begin()); + std::copy(input.begin(), input.begin() + channels(), input_float.begin()); // Create, setup, run, and destroy Convert operator. 
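// Every tester in this file follows the same operator-API lifecycle named by
// the comment above: xnn_initialize, xnn_create_convert_nc_*, a
// std::unique_ptr guard that calls xnn_delete_operator,
// xnn_reshape_convert_nc_* with the batch/channel/stride geometry,
// xnn_setup_convert_nc_* with the input and output pointers (plus the per-row
// quantization parameters for the QD8/QDU8 variants), and finally
// xnn_run_operator.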
ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); @@ -198,46 +184,130 @@ class ConvertOperatorTester { xnnpack::ReplicableRandomDevice rng; xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(float) + - (batch_size() - 1) * input_stride() + channels()); - xnnpack::Buffer output((batch_size() - 1) * output_stride() + channels()); - xnnpack::Buffer quantization_params(batch_size() + XNN_EXTRA_QUANTIZATION_PARAMS); + (batch_size() - 1) * input_stride() + + channels()); + xnnpack::Buffer output((batch_size() - 1) * output_stride() + + channels()); + xnnpack::Buffer quantization_params( + batch_size() + XNN_EXTRA_QUANTIZATION_PARAMS); std::uniform_real_distribution range_dist(-100000, 100000); for (size_t iteration = 0; iteration < iterations(); iteration++) { const float first_val = range_dist(rng); const float second_val = range_dist(rng); - std::uniform_real_distribution f32dist(std::min(first_val, second_val), std::max(first_val, second_val)); + std::uniform_real_distribution f32dist( + std::min(first_val, second_val), std::max(first_val, second_val)); std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); // Create, setup, run, and destroy Convert operator. ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); xnn_operator_t convert_op = nullptr; - ASSERT_EQ(xnn_status_success, - xnn_create_convert_nc_f32_qd8( - 0, &convert_op)); + ASSERT_EQ(xnn_status_success, xnn_create_convert_nc_f32_qd8( + /*flags=*/0, &convert_op)); ASSERT_NE(nullptr, convert_op); // Smart pointer to automatically delete convert op. - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + std::unique_ptr + auto_convert_op(convert_op, xnn_delete_operator); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qd8(convert_op, batch_size(), - channels(), input_stride(), output_stride(), /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8(convert_op, input.data(), output.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f32_qd8( + convert_op, batch_size(), channels(), input_stride(), + output_stride(), /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8( + convert_op, input.data(), output.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); // Verify results. 
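The verification that follows reduces to a per-row round-trip check on the dynamically quantized values; a distilled sketch of that predicate is below. The helper name is made up for illustration, and the divisor of 255 assumes the full 8-bit step count, since the exact std::numeric_limits type argument is garbled in this extract.

#include <cmath>
#include <cstdint>

// A dynamically quantized value must dequantize back to within roughly half a
// quantization step of the original float. rmin/rmax are the row's min/max
// extended to include 0, as computed in the loop below.
bool Qd8RoundTripOk(float original, int8_t quantized, float scale,
                    int32_t zero_point, float rmin, float rmax) {
  // Half a quantization step, with a hair of slack over ideal rounding error.
  const float max_error = 0.5001f * (rmax - rmin) / 255.0f;
  const float dequantized =
      static_cast<float>(static_cast<int32_t>(quantized) - zero_point) * scale;
  return std::fabs(original - dequantized) <= max_error;
}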
for (size_t i = 0; i < batch_size(); i++) { const float* input_ptr = &input[i * input_stride()]; - const auto minmax = std::minmax_element(input_ptr, input_ptr + channels()); + const auto minmax = + std::minmax_element(input_ptr, input_ptr + channels()); const float rmin = math_min_f32(0.0f, *minmax.first); const float rmax = math_max_f32(0.0f, *minmax.second); - const float max_acceptable_error = 0.5001f * (rmax - rmin) / std::numeric_limits::max(); + const float max_acceptable_error = + 0.5001f * (rmax - rmin) / std::numeric_limits::max(); for (size_t c = 0; c < channels(); c++) { float expected = input[i * input_stride() + c]; int8_t quantized_val = output[i * output_stride() + c]; - float dequantized_val = float(quantized_val - quantization_params[i].zero_point) * quantization_params[i].scale; + float dequantized_val = + static_cast(static_cast(quantized_val) - + quantization_params[i].zero_point) * + quantization_params[i].scale; + EXPECT_NEAR(expected, dequantized_val, max_acceptable_error) + << "at batch " << i << " / " << batch_size() << ", channel " << c + << " / " << channels() << " scale " + << quantization_params[i].scale << " zp " + << quantization_params[i].zero_point << " int " + << (int)quantized_val; + } + } + } + } + + void TestF32toQDU8() const { + xnnpack::ReplicableRandomDevice rng; + + xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(float) + + (batch_size() - 1) * input_stride() + + channels()); + xnnpack::Buffer output((batch_size() - 1) * output_stride() + + channels()); + xnnpack::Buffer quantization_params( + batch_size() + XNN_EXTRA_QUANTIZATION_PARAMS); + // std::uniform_real_distribution range_dist(-100000, 100000); + // std::uniform_real_distribution range_dist(-1, 1); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + // const float first_val = range_dist(rng); + // const float second_val = range_dist(rng); + // std::uniform_real_distribution f32dist(std::min(first_val, + // second_val), std::max(first_val, second_val)); + std::uniform_real_distribution f32dist(-1.f, 1.f); + std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); + + // Create, setup, run, and destroy Convert operator. + ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); + xnn_operator_t convert_op = nullptr; + + ASSERT_EQ(xnn_status_success, xnn_create_convert_nc_f32_qdu8( + /*flags=*/0, &convert_op)); + ASSERT_NE(nullptr, convert_op); + + // Smart pointer to automatically delete convert op. + std::unique_ptr + auto_convert_op(convert_op, xnn_delete_operator); + + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f32_qdu8( + convert_op, batch_size(), channels(), input_stride(), + output_stride(), /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qdu8( + convert_op, input.data(), output.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + + // Verify results. 
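// Same round-trip tolerance as the QD8 check above; the only difference in
// the QDU8 variant is that the quantized output is uint8_t, so the per-row
// zero point is subtracted from an unsigned value before scaling.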
+ for (size_t i = 0; i < batch_size(); i++) { + const float* input_ptr = &input[i * input_stride()]; + const auto minmax = + std::minmax_element(input_ptr, input_ptr + channels()); + const float rmin = math_min_f32(0.0f, *minmax.first); + const float rmax = math_max_f32(0.0f, *minmax.second); + const float max_acceptable_error = + 0.5001f * (rmax - rmin) / std::numeric_limits::max(); + for (size_t c = 0; c < channels(); c++) { + float expected = input[i * input_stride() + c]; + uint8_t quantized_val = output[i * output_stride() + c]; + float dequantized_val = + static_cast(quantized_val - + quantization_params[i].zero_point) * + quantization_params[i].scale; EXPECT_NEAR(expected, dequantized_val, max_acceptable_error) - << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); + << "at batch " << i << " / " << batch_size() << ", channel " << c + << " / " << channels(); } } } @@ -250,7 +320,8 @@ class ConvertOperatorTester { const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_nr2_config(); xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(float) + - (batch_size() - 1) * input_stride() + channels()); + (batch_size() - 1) * input_stride() + + channels()); xnnpack::Buffer output(xnn_x8_packq_f32qp8_packed_size( batch_size(), channels(), gemm_config->mr, 1 << gemm_config->log2_kr, 1 << gemm_config->log2_sr)); @@ -267,7 +338,8 @@ class ConvertOperatorTester { xnn_operator_t convert_op = nullptr; ASSERT_EQ(xnn_status_success, - xnn_create_convert_nc_f32_qp8(0, &convert_op)); + xnn_create_convert_nc_f32_qp8( + 0, xnn_init_qp8_f32_qc4w_gemm_config(), &convert_op)); ASSERT_NE(nullptr, convert_op); // Smart pointer to automatically delete convert op. @@ -275,8 +347,9 @@ class ConvertOperatorTester { auto_convert_op(convert_op, xnn_delete_operator); ASSERT_EQ(xnn_status_success, - xnn_reshape_convert_nc_f32_qp8(convert_op, batch_size(), - channels(), input_stride(), + xnn_reshape_convert_nc_f32_qp8(convert_op, /*num_groups=*/1, + batch_size(), channels(), + input_stride(), /*threadpool=*/nullptr)); ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qp8(convert_op, input.data(), @@ -450,3 +523,57 @@ TEST(CONVERT_NC_F32_QP8, small_batch_with_input_stride) { .TestF32toQD8(); } } + +TEST(CONVERT_NC_F32_QDU8, unit_batch) { + for (size_t channels = 1; channels < 100; channels++) { + ConvertOperatorTester() + .batch_size(1) + .channels(channels) + .iterations(3) + .TestF32toQDU8(); + } +} + +TEST(CONVERT_NC_F32_QDU8, small_batch) { + for (size_t channels = 1; channels < 100; channels++) { + ConvertOperatorTester() + .batch_size(3) + .channels(channels) + .iterations(3) + .TestF32toQDU8(); + } +} + +TEST(CONVERT_NC_F32_QDU8, small_batch_with_input_stride) { + for (size_t channels = 10; channels < 11; channels += 15) { + ConvertOperatorTester() + .batch_size(3) + .channels(channels) + .input_stride(129) + .iterations(3) + .TestF32toQDU8(); + } +} + +TEST(CONVERT_NC_F32_QDU8, small_batch_with_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + ConvertOperatorTester() + .batch_size(3) + .channels(channels) + .output_stride(117) + .iterations(3) + .TestF32toQDU8(); + } +} + +TEST(CONVERT_NC_F32_QDU8, small_batch_with_input_and_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + ConvertOperatorTester() + .batch_size(3) + .channels(channels) + .input_stride(129) + .output_stride(117) + .iterations(3) + .TestF32toQDU8(); + } +} diff --git a/test/convolution-2d.cc b/test/convolution-2d.cc index 
0c7601056fdc..e8511f253ffe 100644 --- a/test/convolution-2d.cc +++ b/test/convolution-2d.cc @@ -25,6 +25,7 @@ #include "xnnpack/subgraph.h" #include "convolution-test-helpers.h" #include "replicable_random_device.h" +#include "runtime-flags.h" namespace xnnpack { @@ -382,7 +383,7 @@ TEST_F(ConvolutionTestQD8F16QC8W, internally_allocated_dynamic_quantization_para kernel_width, subsampling_height, subsampling_width, dilation_height, dilation_width, groups, group_input_channels, group_output_channels, output_min, output_max, dq_quantized_id, kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -454,6 +455,7 @@ TEST_F(ConvolutionTestQD8F32QC8W, define) TEST_F(ConvolutionTestQD8F32QC8W, internally_allocated_dynamic_quantization_parameters) { + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); @@ -553,7 +555,7 @@ TEST_F(ConvolutionTestQD8F32QC8W, internally_allocated_dynamic_quantization_para kernel_width, subsampling_height, subsampling_width, dilation_height, dilation_width, groups, group_input_channels, group_output_channels, output_min, output_max, dq_quantized_id, kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -978,7 +980,7 @@ TEST_F(ConvolutionTestQC8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1112,7 +1114,7 @@ TEST_F(ConvolutionTestQS8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1249,7 +1251,7 @@ TEST_F(ConvolutionTestQU8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1341,7 +1343,7 @@ TEST_F(ConvolutionTestF16, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, 
xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1433,7 +1435,7 @@ TEST_F(ConvolutionTestF32, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1527,7 +1529,7 @@ TEST_F(ConvolutionTestF32, transient_indirection_buffer) /*flags=*/XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1587,7 +1589,7 @@ TEST_F(ConvolutionTestF32, reshape_output) std::generate(filter.begin(), filter.end(), [&]() { return f32dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/convolution-operator-tester.h b/test/convolution-operator-tester.h index c7ccd248c4af..8f6b6434abd1 100644 --- a/test/convolution-operator-tester.h +++ b/test/convolution-operator-tester.h @@ -25,6 +25,7 @@ #include "xnnpack/buffer.h" #include "xnnpack/cache.h" #include "xnnpack/common.h" +#include "xnnpack/config.h" #include "xnnpack/math.h" #include "xnnpack/microparams.h" #include "convolution-test-helpers.h" diff --git a/test/copy.cc b/test/copy.cc index 87f65a5e006b..aa357dae3254 100644 --- a/test/copy.cc +++ b/test/copy.cc @@ -17,6 +17,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" using CopyTestQS8 = UnaryTest; using CopyTestQU8 = UnaryTest; @@ -208,7 +209,7 @@ TEST_F(CopyTestQS8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_copy(subgraph, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -265,7 +266,7 @@ TEST_F(CopyTestQU8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_copy(subgraph, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); 
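The matches_operator_api tests in test/copy.cc, like the other subgraph tests touched by this change, all follow roughly the same subgraph-API flow around these calls. Below is a condensed, hypothetical sketch of that flow for the copy node; the function name, tensor ids, and dimensions are placeholders, error handling is elided, and the real tests additionally compare the result against the direct operator-API output.

#include <array>
#include <cstdint>
#include <memory>

#include "xnnpack.h"
#include "runtime-flags.h"

void RunCopyThroughSubgraph(const float* input, float* output,
                            const std::array<size_t, 2>& dims) {
  xnn_initialize(/*allocator=*/nullptr);

  // Build a two-value subgraph with a single copy node.
  xnn_subgraph_t subgraph = nullptr;
  xnn_create_subgraph(/*external_value_ids=*/2, /*flags=*/0, &subgraph);
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(
      subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_VALUE_ID;
  uint32_t output_id = XNN_INVALID_VALUE_ID;
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, dims.size(), dims.data(),
                          /*data=*/nullptr, /*external_id=*/0,
                          XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
  xnn_define_tensor_value(subgraph, xnn_datatype_fp32, dims.size(), dims.data(),
                          /*data=*/nullptr, /*external_id=*/1,
                          XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
  xnn_define_copy(subgraph, input_id, output_id, /*flags=*/0);

  // Create the runtime with the shared test flags, bind the external
  // tensors, and invoke it.
  xnn_runtime_t runtime = nullptr;
  xnn_create_runtime_v3(subgraph, /*weights_cache=*/nullptr,
                        /*threadpool=*/nullptr, xnn_test_runtime_flags(),
                        &runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(
      runtime, xnn_delete_runtime);

  std::array<xnn_external_value, 2> external = {
      xnn_external_value{input_id, const_cast<float*>(input)},
      xnn_external_value{output_id, output}};
  xnn_setup_runtime(runtime, external.size(), external.data());
  xnn_invoke_runtime(runtime);
}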
std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -319,7 +320,7 @@ TEST_F(CopyTestF16, matches_operator_api) xnn_status_success, xnn_define_copy(subgraph, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -372,7 +373,7 @@ TEST_F(CopyTestF32, matches_operator_api) xnn_status_success, xnn_define_copy(subgraph, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/deconvolution-2d.cc b/test/deconvolution-2d.cc index c973d2255bbe..ed06f9a55721 100644 --- a/test/deconvolution-2d.cc +++ b/test/deconvolution-2d.cc @@ -16,6 +16,7 @@ #include #include "xnnpack.h" #include "xnnpack/buffer.h" +#include "xnnpack/internal.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" #include "xnnpack/operator-utils.h" @@ -23,6 +24,7 @@ #include "xnnpack/requantization.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class DeconvolutionTestBase : public ::testing::Test { protected: @@ -641,7 +643,7 @@ TEST_F(DeconvolutionTestQS8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -783,7 +785,7 @@ TEST_F(DeconvolutionTestQU8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -871,7 +873,7 @@ TEST_F(DeconvolutionTestF16, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -908,7 +910,7 @@ TEST_F(DeconvolutionTestQD8F32QC8W, internally_allocated_dynamic_quantization_pa xnn_operator_t convert_op = nullptr; const size_t quantized_batch_size = input_height * input_width * group_input_channels * groups; xnn_status status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &convert_op); + /*flags=*/0, &convert_op); std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -994,7 +996,7 @@ 
TEST_F(DeconvolutionTestQD8F32QC8W, internally_allocated_dynamic_quantization_pa /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1082,7 +1084,7 @@ TEST_F(DeconvolutionTestF32, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1139,7 +1141,7 @@ TEST_F(DeconvolutionTestF32, reshape_output) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/depth-to-space-2d.cc b/test/depth-to-space-2d.cc index 95d0e8875bcf..34a07410cb67 100644 --- a/test/depth-to-space-2d.cc +++ b/test/depth-to-space-2d.cc @@ -24,6 +24,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class DepthToSpaceTest : public ::testing::Test { protected: @@ -299,7 +300,7 @@ TEST_F(DepthToSpaceTestQS8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_depth_to_space_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -363,7 +364,7 @@ TEST_F(DepthToSpaceTestQU8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_depth_to_space_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -424,7 +425,7 @@ TEST_F(DepthToSpaceTestF16, matches_operator_api) xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, xnn_define_depth_to_space_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -484,7 +485,7 @@ TEST_F(DepthToSpaceTestF32, matches_operator_api) xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, 
xnn_define_depth_to_space_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -520,7 +521,7 @@ TEST_F(DepthToSpaceTestF32, reshape_output) xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, xnn_define_depth_to_space_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/depthwise-convolution-2d.cc b/test/depthwise-convolution-2d.cc index f4f08c3df18c..22b8e2f20034 100644 --- a/test/depthwise-convolution-2d.cc +++ b/test/depthwise-convolution-2d.cc @@ -26,6 +26,7 @@ #include "xnnpack/subgraph.h" #include "convolution-test-helpers.h" #include "replicable_random_device.h" +#include "runtime-flags.h" namespace xnnpack { @@ -658,7 +659,7 @@ TEST_F(DepthwiseConvolutionTestQC8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -794,7 +795,7 @@ TEST_F(DepthwiseConvolutionTestQS8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -932,7 +933,7 @@ TEST_F(DepthwiseConvolutionTestQU8, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1023,7 +1024,7 @@ TEST_F(DepthwiseConvolutionTestF16, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1115,7 +1116,7 @@ TEST_F(DepthwiseConvolutionTestF32, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), 
&runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1130,6 +1131,10 @@ TEST_F(DepthwiseConvolutionTestF32, reshape_output) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); + std::generate(filter.begin(), filter.end(), [&]() { return f32dist(rng); }); + std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); + // Call subgraph API. xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); @@ -1169,7 +1174,7 @@ TEST_F(DepthwiseConvolutionTestF32, reshape_output) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -1278,7 +1283,7 @@ TEST_F(DepthwiseConvolutionTestF32, transient_indirection_buffer) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/even-split2.cc b/test/even-split2.cc index bfde24b732c4..f5012a6b6ccc 100644 --- a/test/even-split2.cc +++ b/test/even-split2.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class EvenSplit2Test : public ::testing::Test { protected: @@ -354,7 +355,7 @@ TEST_F(EvenSplit2TestQS8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_even_split2(subgraph, axis, input_id, output1_id, output2_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -429,7 +430,7 @@ TEST_F(EvenSplit2TestQU8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_even_split2(subgraph, axis, input_id, output1_id, output2_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -501,7 +502,7 @@ TEST_F(EvenSplit2TestF16, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_even_split2(subgraph, axis, input_id, output1_id, output2_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr 
auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -573,7 +574,7 @@ TEST_F(EvenSplit2TestF32, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_even_split2(subgraph, axis, input_id, output1_id, output2_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -623,7 +624,7 @@ TEST_F(EvenSplit2TestF32, reshape_output) xnn_define_even_split2(subgraph, axis, input_id, output1_id, output2_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/even-split3.cc b/test/even-split3.cc index 5a35f520d85f..05d5c735c752 100644 --- a/test/even-split3.cc +++ b/test/even-split3.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class EvenSplit3Test : public ::testing::Test { protected: @@ -423,7 +424,7 @@ TEST_F(EvenSplit3TestQS8, matches_operator_api) xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -518,7 +519,7 @@ TEST_F(EvenSplit3TestQU8, matches_operator_api) xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -609,7 +610,7 @@ TEST_F(EvenSplit3TestF16, matches_operator_api) xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -700,7 +701,7 @@ TEST_F(EvenSplit3TestF32, matches_operator_api) xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); 
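// -----------------------------------------------------------------------------
// The change repeated throughout these subgraph tests is mechanical: each
// xnn_create_runtime_v3() call that used to pass a hard-coded `/*flags=*/0`
// now forwards the value returned by xnn_test_runtime_flags() from the new
// "runtime-flags.h" test helper (built as the `runtime-flags` library in
// CMakeLists.txt), so one switch can apply extra runtime flags to the whole
// test suite. A minimal sketch of the pattern, assuming the helper simply
// returns a flags bitmask (its definition is not shown in this diff):
#include <gtest/gtest.h>
#include "xnnpack.h"
#include "runtime-flags.h"  // assumed to declare: uint32_t xnn_test_runtime_flags();

// Hypothetical helper mirroring the call sites in these tests; not part of the patch.
static xnn_runtime_t CreateTestRuntime(xnn_subgraph_t subgraph) {
  xnn_runtime_t runtime = nullptr;
  // Before: xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)
  EXPECT_EQ(xnn_status_success,
            xnn_create_runtime_v3(subgraph, nullptr, nullptr,
                                  xnn_test_runtime_flags(), &runtime));
  return runtime;
}
// -----------------------------------------------------------------------------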
ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -759,7 +760,7 @@ TEST_F(EvenSplit3TestF32, reshape_output) xnn_define_even_split3(subgraph, axis, input_id, output1_id, output2_id, output3_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/even-split4.cc b/test/even-split4.cc index 72fd0b3ae7ef..aa405d87924d 100644 --- a/test/even-split4.cc +++ b/test/even-split4.cc @@ -23,6 +23,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class EvenSplit4Test : public ::testing::Test { protected: @@ -480,7 +481,7 @@ TEST_F(EvenSplit4TestQS8, matches_operator_api) xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -593,7 +594,7 @@ TEST_F(EvenSplit4TestQU8, matches_operator_api) xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -701,7 +702,7 @@ TEST_F(EvenSplit4TestF16, matches_operator_api) xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -809,7 +810,7 @@ TEST_F(EvenSplit4TestF32, matches_operator_api) xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -877,7 +878,7 @@ TEST_F(EvenSplit4TestF32, reshape_output) xnn_define_even_split4(subgraph, axis, input_id, output1_id, output2_id, output3_id, output4_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + 
ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/f16-f32acc-gemm-minmax.cc b/test/f16-f32acc-gemm-minmax.cc index 08001b286462..c69852762584 100644 --- a/test/f16-f32acc-gemm-minmax.cc +++ b/test/f16-f32acc-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_1x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -322,6 +302,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_1x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -341,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_3x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -360,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_4x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -379,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_4x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -398,6 +382,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_5x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -417,6 +402,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_5x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -436,6 +422,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_6x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -455,6 +442,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_gemm_minmax_ukernel_7x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, diff --git a/test/f16-f32acc-igemm-minmax.cc b/test/f16-f32acc-igemm-minmax.cc index 02d96ce92e4d..f3f3063b1fa1 100644 --- a/test/f16-f32acc-igemm-minmax.cc +++ b/test/f16-f32acc-igemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_1x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -322,6 +302,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_1x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -341,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_3x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -360,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_4x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -379,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_4x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -398,6 +382,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_5x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -417,6 +402,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_5x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -436,6 +422,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_6x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -455,6 +442,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_f32acc_igemm_minmax_ukernel_7x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, diff --git a/test/f16-gemm-minmax.cc b/test/f16-gemm-minmax.cc index 7c66a7149214..de04e559eb05 100644 --- a/test/f16-gemm-minmax.cc +++ b/test/f16-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( 
"k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x8__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x8__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x8__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_8x8__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x16__asm_aarch64_neonfp16arith_ld32, xnn_init_f16_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x16__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -427,6 +412,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x16__asm_aarch64_neonfp16arith_ld32, xnn_init_f16_minmax_scalar_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x16__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -465,6 +452,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a55, xnn_init_f16_minmax_scalar_params, @@ -484,6 +472,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a55r0, xnn_init_f16_minmax_scalar_params, @@ -503,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a75, xnn_init_f16_minmax_scalar_params, @@ -522,6 +512,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_ld32, xnn_init_f16_minmax_scalar_params, @@ -541,6 +532,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -563,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -582,6 +575,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -601,6 +595,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -620,6 +615,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_8x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -639,6 +635,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -658,6 +655,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -677,6 +675,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -696,6 +695,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
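// -----------------------------------------------------------------------------
// The other change repeated across these f16/f32 GEMM and IGEMM test files is
// in the test generators: the CreateTests1()/CreateTests2() helpers gain a
// trailing `bool unsigned_inputs` parameter that is forwarded to the shared
// tester, and the strided_cn variants (strided_cn, n_gt_*_strided_cn,
// n_div_*_strided_cn) are no longer generated; the call sites shown here all
// pass /*unsigned_inputs=*/false. A condensed sketch of the new plumbing, with
// the template arguments (stripped by the flattened formatting above) restored
// on a best-effort basis:
#include <cstddef>
#include <functional>
#include <string>
#include <vector>
#include "gemm-microkernel-tester.h"  // assumed to provide GemmMicrokernelTester and GemmTestParams

std::vector<GemmTestParams> CreateTests1(
    size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr,
    size_t sr, bool is_igemm, bool unsigned_inputs,
    std::function<void(GemmMicrokernelTester& tester)> test_func,
    std::function<void()> isa_check = nullptr) {
  // The new flag rides along on the prototype tester, so every generated
  // variant inherits it.
  const GemmMicrokernelTester tester = GemmMicrokernelTester()
      .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs);
  std::vector<GemmTestParams> gemm_tests;
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + std::to_string(k_block),
      tester.clone().m(mr).n(nr).k(k_block), test_func, isa_check));
  // ... remaining k/m/n sweeps exactly as in the files above, minus the
  // strided_cn cases ...
  return gemm_tests;
}
// -----------------------------------------------------------------------------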
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -718,6 +718,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -737,6 +738,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -756,6 +758,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_5x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -775,6 +778,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -794,6 +798,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_7x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -813,6 +818,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_8x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -832,6 +838,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -851,6 +858,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -870,6 +878,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_5x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -889,6 +898,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -908,6 +918,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_7x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -927,6 +938,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_8x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -949,6 +961,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -968,6 +981,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -987,6 +1001,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_5x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1006,6 +1021,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_6x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1025,6 +1041,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_7x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1044,6 +1061,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_1x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1063,6 +1081,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_3x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1082,6 +1101,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1101,6 +1121,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_gemm_minmax_ukernel_5x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, diff --git a/test/f16-igemm-minmax.cc b/test/f16-igemm-minmax.cc index 7eeba53dcee1..8398733f0a5c 100644 --- a/test/f16-igemm-minmax.cc +++ b/test/f16-igemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - 
gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x16__asm_aarch64_neonfp16arith_ld32, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x16__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x16__asm_aarch64_neonfp16arith_ld32, xnn_init_f16_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x16__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a55, xnn_init_f16_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a55r0, xnn_init_f16_minmax_scalar_params, @@ -427,6 +412,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_cortex_a75, xnn_init_f16_minmax_scalar_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f16_igemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_ld32, xnn_init_f16_minmax_scalar_params, @@ -465,6 +452,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x16__asm_aarch64_neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -487,6 +475,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -506,6 +495,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -525,6 +515,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -544,6 +535,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -563,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -582,6 +575,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -601,6 +595,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_8x8__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -620,6 +615,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_8x16__neonfp16arith_ld64, xnn_init_f16_minmax_scalar_params, @@ -642,6 +638,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -661,6 +658,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -680,6 +678,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f16_igemm_minmax_ukernel_3x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -699,6 +698,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -718,6 +718,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -737,6 +738,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_5x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -756,6 +758,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_5x16__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -775,6 +778,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -794,6 +798,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_7x8__avx2_broadcast, xnn_init_f16_minmax_scalar_params, @@ -816,6 +821,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -835,6 +841,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -854,6 +861,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_5x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -873,6 +881,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -892,6 +901,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_7x32__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -911,6 +921,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_8x32__avx512fp16_broadcast, 
xnn_init_f16_minmax_scalar_params, @@ -930,6 +941,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_1x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -949,6 +961,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_4x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -968,6 +981,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_5x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -987,6 +1001,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_6x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1006,6 +1021,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_7x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, @@ -1025,6 +1041,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f16_igemm_minmax_ukernel_8x64__avx512fp16_broadcast, xnn_init_f16_minmax_scalar_params, diff --git a/test/f32-gemm-2.cc b/test/f32-gemm-2.cc index ef2e5e9801f7..4584bcd09d88 100644 --- a/test/f32-gemm-2.cc +++ b/test/f32-gemm-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - 
.m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_5x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -642,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_6x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_6x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -675,6 +640,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -690,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -780,6 +752,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -795,6 +768,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -812,6 +786,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x4__scalar, xnn_pack_f32_gemm_goi_w); @@ -828,6 +803,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x2__scalar, xnn_pack_f32_gemm_goi_w); @@ -844,6 +820,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x4__scalar, xnn_pack_f32_gemm_goi_w); diff --git a/test/f32-gemm-goi-minmax.cc b/test/f32-gemm-goi-minmax.cc index 4d3674ae7be1..9c7f0f0643a7 100644 --- a/test/f32-gemm-goi-minmax.cc +++ b/test/f32-gemm-goi-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_goi_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params); @@ -321,6 +301,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_goi_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_prfm, xnn_init_f32_minmax_scalar_params); @@ -339,6 +320,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_goi_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params); diff --git a/test/f32-gemm-minmax-2.cc b/test/f32-gemm-minmax-2.cc index 8351af37b214..962a5614f859 100644 --- a/test/f32-gemm-minmax-2.cc +++ b/test/f32-gemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector 
CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -306,6 +285,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -314,7 +294,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -324,12 +304,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -457,14 +431,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -490,14 +456,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -598,6 +556,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, 
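The blocks being deleted here are the "strided_cn", "n_gt_<nr>_strided_cn", and "n_div_<nr>_strided_cn" cases; they were the only variants that set .cn_stride(xnnpack::NextPrime(nr + 1)) on the tester, so these suites no longer exercise a non-default cn_stride. The short stand-alone helper below only illustrates which cn_stride values those removed cases would have used; it assumes NextPrime(n) returns the smallest prime not less than n (the real helper lives in test/next_prime.cc):

#include <cstddef>
#include <cstdio>

// Stand-in for xnnpack::NextPrime, assuming "smallest prime >= n" semantics.
static bool IsPrime(size_t n) {
  if (n < 2) return false;
  for (size_t d = 2; d * d <= n; d++) {
    if (n % d == 0) return false;
  }
  return true;
}

static size_t NextPrime(size_t n) {
  while (!IsPrime(n)) {
    n++;
  }
  return n;
}

int main() {
  // nr values used by kernels in these files; the deleted cases would have
  // set cn_stride = NextPrime(nr + 1).
  const size_t nrs[] = {2, 4, 8, 16, 32, 64};
  for (size_t nr : nrs) {
    std::printf("nr=%zu -> cn_stride=%zu\n", nr, NextPrime(nr + 1));
  }
  return 0;
}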
std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -607,7 +566,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -617,12 +576,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -730,14 +683,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -763,14 +708,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -878,6 +815,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -897,6 +835,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x4__asm_aarch32_vfp_ld64, xnn_init_f32_minmax_scalar_params, @@ -913,6 +852,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -932,6 +872,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_ld64, xnn_init_f32_minmax_scalar_params, @@ -954,6 +895,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neon_ld128_acc2_prfm, xnn_init_f32_minmax_scalar_params, @@ -973,6 +915,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -992,6 +935,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, 
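The suites in this file cover the minmax kernel variants: each test lambda passes xnn_init_f32_minmax_scalar_params alongside the packing function, because these microkernels clamp their outputs to a [min, max] range supplied through a params struct. The stand-alone sketch below shows only that clamping idea with simplified stand-in types; it is not the real xnn_f32_minmax_params layout or init signature:

#include <algorithm>
#include <cstdio>

// Stand-in for the minmax params struct filled by an init function.
struct SketchMinMaxParams {
  float min;
  float max;
};

static void init_minmax_params(SketchMinMaxParams* params, float min, float max) {
  params->min = min;
  params->max = max;
}

// Conceptual output epilogue of a *_minmax_* kernel: clamp each accumulator.
static float clamp_output(float acc, const SketchMinMaxParams& params) {
  return std::min(std::max(acc, params.min), params.max);
}

int main() {
  SketchMinMaxParams params;
  init_minmax_params(&params, /*min=*/0.0f, /*max=*/6.0f);  // e.g. ReLU6-style bounds
  const float accumulators[] = {-1.5f, 3.25f, 9.0f};
  for (float acc : accumulators) {
    std::printf("%f -> %f\n", acc, clamp_output(acc, params));
  }
  return 0;
}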
/*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1011,6 +955,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1030,6 +975,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc2, xnn_init_f32_minmax_scalar_params, @@ -1049,6 +995,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc4_prfm, xnn_init_f32_minmax_scalar_params, @@ -1068,6 +1015,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2_prfm, xnn_init_f32_minmax_scalar_params, @@ -1087,6 +1035,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4, xnn_init_f32_minmax_scalar_params, @@ -1106,6 +1055,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_prfm, xnn_init_f32_minmax_scalar_params, @@ -1125,6 +1075,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/1, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x1__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1144,6 +1095,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1163,6 +1115,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -1182,6 +1135,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -1201,6 +1155,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75, 
xnn_init_f32_minmax_scalar_params, @@ -1220,6 +1175,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1232,6 +1188,166 @@ std::vector CreateTests2( return info.param.test_name; }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + 
tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X16__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X8__ASM_AARCH64_NEONFMA_LD64, GemmTest, testing::ValuesIn(CreateTests1( @@ -1239,6 +1355,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1258,6 +1375,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1277,6 +1395,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1296,6 +1415,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73, xnn_init_f32_minmax_scalar_params, @@ -1315,6 +1435,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1337,6 +1458,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_2x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1356,6 +1478,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
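The F32_GEMM_MINMAX_*X8__ASM_AARCH64_NEONFMA_LANE and *X16__ASM_AARCH64_NEONFMA_LANE suites added above are new in this diff; they all use k_block=1, kr=1, sr=1, gate on TEST_REQUIRES_ARM_NEON_FMA, and reuse the same CreateTests1 factory plus a name-generator lambda. The stand-alone GoogleTest example below shows that value-parameterized wiring pattern with simplified stand-in types (CaseParams and CreateCases are illustrations, not XNNPACK code); it assumes linking against gtest_main:

#include <cstddef>
#include <string>
#include <vector>

#include <gtest/gtest.h>

struct CaseParams {
  std::string test_name;  // becomes the test-name suffix, like in GemmTestParams
  size_t m;
  size_t n;
  size_t k;
};

// Analogue of CreateTests1(): build a vector of named cases for one kernel shape.
static std::vector<CaseParams> CreateCases(size_t mr, size_t nr, size_t k_block) {
  return {
      {"k_eq_" + std::to_string(k_block), mr, nr, k_block},
      {"k_gt_" + std::to_string(k_block), mr, nr, k_block + 1},
  };
}

class SketchGemmTest : public testing::TestWithParam<CaseParams> {};

TEST_P(SketchGemmTest, Runs) {
  const CaseParams& p = GetParam();
  EXPECT_GT(p.m * p.n * p.k, 0u);
}

// Same shape as the INSTANTIATE_TEST_SUITE_P calls in this file: a generator
// over the case vector plus a lambda that names each instantiated test.
INSTANTIATE_TEST_SUITE_P(
    F32_GEMM_SKETCH_1X8, SketchGemmTest,
    testing::ValuesIn(CreateCases(/*mr=*/1, /*nr=*/8, /*k_block=*/1)),
    [](const testing::TestParamInfo<SketchGemmTest::ParamType>& info) {
      return info.param.test_name;
    });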
tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1375,6 +1498,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1394,6 +1518,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1416,6 +1541,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1435,6 +1561,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1454,6 +1581,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1476,6 +1604,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1498,6 +1627,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1520,6 +1650,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1539,6 +1670,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1558,6 +1690,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1577,6 +1710,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__neonfma_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1596,6 +1730,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1615,6 +1750,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__neon_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1634,6 +1770,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1653,6 +1790,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1672,6 +1810,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__neonfma_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1691,6 +1830,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1710,6 +1850,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1732,6 +1873,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1751,6 +1893,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1770,6 +1913,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1789,6 +1933,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2c4__sse, xnn_init_f32_minmax_scalar_params, @@ -1808,6 +1953,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1827,6 +1973,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1846,6 +1993,7 @@ std::vector CreateTests2( 
/*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1865,6 +2013,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1884,6 +2033,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1903,6 +2053,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1922,6 +2073,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1941,6 +2093,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1960,6 +2113,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1979,6 +2133,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1998,6 +2153,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2009,7 +2165,293 @@ std::vector CreateTests2( [](const testing::TestParamInfo& info) { return info.param.test_name; }); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_7X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_8X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_11X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/11, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + 
tester.Test(xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_7X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_8X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_11X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/11, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + 
xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X16__FMA3_BROADCAST, GemmTest, testing::ValuesIn(CreateTests1( @@ -2017,6 +2459,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2036,6 +2479,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2055,6 +2499,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2074,6 +2519,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2093,6 +2539,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2112,6 +2559,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2131,6 +2579,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2150,6 +2599,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2169,6 +2619,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2191,6 +2642,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2210,6 +2662,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x16__avx512f_broadcast, 
xnn_init_f32_minmax_scalar_params, @@ -2229,6 +2682,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/12, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_12x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2248,6 +2702,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/13, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_13x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2267,6 +2722,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/14, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_14x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2286,6 +2742,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/15, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_15x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2305,6 +2762,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2324,6 +2782,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2343,6 +2802,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/12, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_12x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2362,6 +2822,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/13, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_13x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2381,6 +2842,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/14, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_14x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2400,6 +2862,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/15, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_15x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2419,6 +2882,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2438,6 +2902,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x64__avx512f_broadcast, 
xnn_init_f32_minmax_scalar_params, @@ -2457,6 +2922,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/12, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_12x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2476,6 +2942,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/13, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_13x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2495,6 +2962,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/14, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_14x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2514,6 +2982,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/15, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_15x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2536,6 +3005,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2552,6 +3022,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2568,6 +3039,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2584,6 +3056,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2600,6 +3073,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2616,6 +3090,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2632,6 +3107,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2648,6 +3124,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, 
@@ -2664,6 +3141,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2680,6 +3158,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2696,6 +3175,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2712,6 +3192,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2728,6 +3209,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2744,6 +3226,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2760,6 +3243,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2776,6 +3260,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2792,6 +3277,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2811,6 +3297,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2827,6 +3314,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2843,6 +3331,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -2859,6 +3348,7 @@ std::vector 
CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2875,6 +3365,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2891,6 +3382,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2907,6 +3399,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2923,6 +3416,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2939,6 +3433,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2c4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2955,6 +3450,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2971,6 +3467,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2987,6 +3484,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3003,6 +3501,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3019,6 +3518,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3035,6 +3535,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3051,6 +3552,7 @@ 
std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3067,6 +3569,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3083,6 +3586,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3099,6 +3603,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3115,6 +3620,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3131,6 +3637,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3147,6 +3654,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3163,6 +3671,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3182,6 +3691,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -3200,6 +3710,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -3218,6 +3729,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x32__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3237,6 +3749,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x64__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3256,6 +3769,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, 
/*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x64__hvx_broadcast, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-gemm-minmax.cc b/test/f32-gemm-minmax.cc index c3e5cfeb3f24..b62809225bf6 100644 --- a/test/f32-gemm-minmax.cc +++ b/test/f32-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -306,6 +285,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -314,7 +294,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -324,12 +304,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -457,14 +431,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { 
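Throughout these factories a single named GemmTestParams entry is expanded into a sweep by chaining .loop_n(...) and .loop_k(...), e.g. .loop_n(nr + 1, nr * 2 - 1).loop_k(1, k_block * 3, k_block + 1) for the n_gt_<nr> cases. The stand-alone sketch below models that expansion; the range struct, inclusive bounds, and default step of 1 are assumptions, not the real GemmTestParams API:

#include <cstddef>
#include <cstdio>

struct LoopRange {
  size_t from = 0;
  size_t to = 0;
  size_t step = 1;
};

// Stand-in for the loop_n()/loop_k() chaining on GemmTestParams.
struct SketchTestParams {
  SketchTestParams& loop_n(size_t from, size_t to, size_t step = 1) {
    n = {from, to, step};
    return *this;
  }
  SketchTestParams& loop_k(size_t from, size_t to, size_t step = 1) {
    k = {from, to, step};
    return *this;
  }
  LoopRange n;
  LoopRange k;
};

int main() {
  const size_t nr = 8;
  const size_t k_block = 4;
  // Mirrors: .loop_n(nr + 1, nr * 2 - 1).loop_k(1, k_block * 3, k_block + 1)
  SketchTestParams params;
  params.loop_n(nr + 1, nr * 2 - 1).loop_k(1, k_block * 3, k_block + 1);
  for (size_t n = params.n.from; n <= params.n.to; n += params.n.step) {
    for (size_t k = params.k.from; k <= params.k.to; k += params.k.step) {
      std::printf("n=%zu k=%zu\n", n, k);
    }
  }
  return 0;
}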
gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -490,14 +456,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -598,6 +556,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -607,7 +566,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -617,12 +576,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -730,14 +683,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -763,14 +708,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -878,6 +815,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -897,6 +835,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a7, xnn_init_f32_minmax_scalar_params, @@ -916,6 +855,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -935,6 +875,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55, 
xnn_init_f32_minmax_scalar_params, @@ -954,6 +895,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -973,6 +915,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -995,6 +938,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neon_ld128_acc2, xnn_init_f32_minmax_scalar_params, @@ -1014,6 +958,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1033,6 +978,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1052,6 +998,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc2_prfm, xnn_init_f32_minmax_scalar_params, @@ -1071,6 +1018,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc4, xnn_init_f32_minmax_scalar_params, @@ -1090,6 +1038,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_prfm, xnn_init_f32_minmax_scalar_params, @@ -1109,6 +1058,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1128,6 +1078,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2, xnn_init_f32_minmax_scalar_params, @@ -1147,6 +1098,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4_prfm, xnn_init_f32_minmax_scalar_params, @@ -1166,6 +1118,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/12, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) 
{ tester.Test(xnn_f32_gemm_minmax_ukernel_1x12__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1185,6 +1138,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/1, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x1__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1204,6 +1158,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1223,6 +1178,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1242,6 +1198,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1261,6 +1218,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1273,6 +1231,46 @@ std::vector CreateTests2( return info.param.test_name; }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X8__ASM_AARCH64_NEONFMA_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_FMA; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X8__ASM_AARCH64_NEONFMA_LD128, GemmTest, testing::ValuesIn(CreateTests1( @@ -1280,6 +1278,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1299,6 +1298,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/12, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x12__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, 
@@ -1318,6 +1318,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1337,6 +1338,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -1356,6 +1358,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -1375,6 +1378,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1394,6 +1398,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1413,6 +1418,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1435,6 +1441,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1454,6 +1461,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1476,6 +1484,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_2x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1495,6 +1504,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1514,6 +1524,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1536,6 +1547,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1555,6 +1567,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1577,6 +1590,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1596,6 +1610,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1615,6 +1630,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1637,6 +1653,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1656,6 +1673,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1675,6 +1693,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1694,6 +1713,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1716,6 +1736,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1735,6 +1756,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1757,6 +1779,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__neon_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1776,6 +1799,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1795,6 +1819,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1814,6 +1839,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1836,6 +1862,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1858,6 +1885,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1880,6 +1908,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x2__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1902,6 +1931,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x2__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1924,6 +1954,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1946,6 +1977,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1965,6 +1997,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1984,6 +2017,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -2006,6 +2040,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -2025,6 +2060,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__sse_load1, 
xnn_init_f32_minmax_scalar_params, @@ -2044,6 +2080,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -2063,6 +2100,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -2082,6 +2120,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -2101,6 +2140,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -2120,6 +2160,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x2c4__sse, xnn_init_f32_minmax_scalar_params, @@ -2139,6 +2180,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -2158,6 +2200,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -2177,6 +2220,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2196,6 +2240,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2215,6 +2260,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2234,6 +2280,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2253,6 +2300,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2272,6 +2320,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2291,6 +2340,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2310,6 +2360,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2321,7 +2372,273 @@ std::vector CreateTests2( [](const testing::TestParamInfo& info) { return info.param.test_name; }); +#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 + + +#if XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_6X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_9X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/9, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_10X16__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/10, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + 
[](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_6X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_9X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/9, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_10X32__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/10, /*nr=*/32, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast, + 
xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X64__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/5, /*nr=*/64, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ARCH_X86 || XNN_ARCH_X86_64 INSTANTIATE_TEST_SUITE_P( F32_GEMM_MINMAX_4X8__FMA3_BROADCAST, GemmTest, testing::ValuesIn(CreateTests1( @@ -2329,6 +2646,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2348,6 +2666,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2370,6 +2689,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2389,6 +2709,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2408,6 +2729,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2427,6 +2749,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2446,6 +2769,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/9, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_9x16__avx512f_broadcast, 
xnn_init_f32_minmax_scalar_params, @@ -2465,6 +2789,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/10, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_10x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2484,6 +2809,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/11, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_11x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2503,6 +2829,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/16, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_16x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2522,6 +2849,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2541,6 +2869,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2560,6 +2889,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2579,6 +2909,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2598,6 +2929,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/9, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_9x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2617,6 +2949,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/10, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_10x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2636,6 +2969,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/11, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_11x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2655,6 +2989,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/16, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_16x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2674,6 +3009,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x64__avx512f_broadcast, 
xnn_init_f32_minmax_scalar_params, @@ -2693,6 +3029,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2712,6 +3049,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2731,6 +3069,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2750,6 +3089,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/9, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_9x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2769,6 +3109,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/10, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_10x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2788,6 +3129,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/11, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_11x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2807,6 +3149,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/16, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_16x64__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2829,6 +3172,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2845,6 +3189,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2861,6 +3206,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2877,6 +3223,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2893,6 +3240,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8s4__wasmsimd_arm, 
xnn_init_f32_minmax_scalar_params, @@ -2909,6 +3257,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2925,6 +3274,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2c4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2941,6 +3291,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2c4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2957,6 +3308,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2973,6 +3325,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2989,6 +3342,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -3005,6 +3359,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3021,6 +3376,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -3037,6 +3393,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -3053,6 +3410,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -3072,6 +3430,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3088,6 +3447,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ 
-3104,6 +3464,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3120,6 +3481,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3136,6 +3498,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3152,6 +3515,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3168,6 +3532,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3184,6 +3549,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3200,6 +3566,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3219,6 +3586,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_2x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -3235,6 +3603,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -3251,6 +3620,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -3269,6 +3639,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -3286,6 +3657,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -3303,6 +3675,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, 
/*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -3321,6 +3694,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -3340,6 +3714,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -3362,6 +3737,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/128, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_1x128__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3381,6 +3757,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/128, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_2x128__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3400,6 +3777,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_7x64__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3419,6 +3797,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_8x32__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3438,6 +3817,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/16, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_minmax_ukernel_16x32__hvx_broadcast, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-gemm-minmax.yaml b/test/f32-gemm-minmax.yaml index cd15dfee194c..bd67b87c1ed2 100644 --- a/test/f32-gemm-minmax.yaml +++ b/test/f32-gemm-minmax.yaml @@ -180,6 +180,48 @@ pack: xnn_pack_f32_gemm_goi_w k-block: 8 pipelined: true + +- name: xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: 
xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 + - name: xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w @@ -570,6 +612,114 @@ init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w k-block: 4 +- name: xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: 
xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 +- name: xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 1 - name: xnn_f32_gemm_minmax_ukernel_4x8__fma3_broadcast init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w diff --git a/test/f32-gemm-relu-2.cc b/test/f32-gemm-relu-2.cc index bc92af5684b9..775cde7d1f37 100644 --- a/test/f32-gemm-relu-2.cc +++ b/test/f32-gemm-relu-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 
+276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_3x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_3x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_3x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x2c4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -642,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, 
/*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_5x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -672,6 +637,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_5x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -687,6 +653,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_6x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -780,6 +752,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -795,6 +768,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -810,6 +784,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -825,6 +800,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -843,6 +819,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemm_relu_ukernel_2x4__wasm, xnn_pack_f32_gemm_goi_w); @@ -858,6 +835,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x2__wasm, xnn_pack_f32_gemm_goi_w); @@ -873,6 +851,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x4__wasm, xnn_pack_f32_gemm_goi_w); @@ -890,6 +869,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_2x4__scalar, xnn_pack_f32_gemm_goi_w); @@ -907,6 +887,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x4v__rvv, xnn_pack_f32_gemm_goi_w); @@ -925,6 +906,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_7x4v__rvv, xnn_pack_f32_gemm_goi_w); diff --git a/test/f32-gemm-relu.cc b/test/f32-gemm-relu.cc index 8385b587ace0..be7ed73f03d7 100644 --- a/test/f32-gemm-relu.cc +++ b/test/f32-gemm-relu.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool 
unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_5x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -642,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_6x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemm_relu_ukernel_6x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -675,6 +640,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -690,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_gemm_goi_w); @@ -783,6 +755,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x4__wasm, xnn_pack_f32_gemm_goi_w); @@ -800,6 +773,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_1x4__scalar, xnn_pack_f32_gemm_goi_w); @@ -816,6 +790,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x2__scalar, xnn_pack_f32_gemm_goi_w); @@ -832,6 +807,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_relu_ukernel_4x4__scalar, xnn_pack_f32_gemm_goi_w); diff --git a/test/f32-gemm.cc b/test/f32-gemm.cc index 4bb856df919a..b097f16cfc8b 100644 --- a/test/f32-gemm.cc +++ b/test/f32-gemm.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 
+44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x4__asm_aarch32_vfp_ld64, xnn_pack_f32_gemm_goi_w); @@ -585,6 +544,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_3x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -600,6 +560,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_3x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -615,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_3x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -630,6 +592,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x2c4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -645,6 +608,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -660,6 +624,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -675,6 +640,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_5x8__wasmsimd_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -690,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_5x8s4__wasmsimd, xnn_pack_f32_gemm_goi_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_6x8__wasmsimd_splat, xnn_pack_f32_gemm_goi_w); @@ -723,6 +691,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -738,6 +707,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -753,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -768,6 +739,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -783,6 +755,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_gemm_goi_w); @@ -798,6 +771,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -813,6 +787,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_gemm_goi_w); @@ -830,6 +805,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_2x4__scalar, xnn_pack_f32_gemm_goi_w); @@ -847,6 +823,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_1x4v__rvv, xnn_pack_f32_gemm_goi_w); @@ -865,6 +842,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemm_ukernel_7x4v__rvv, xnn_pack_f32_gemm_goi_w); diff --git a/test/f32-gemminc-minmax-2.cc b/test/f32-gemminc-minmax-2.cc index 6aa13fd08d58..85198c8aeffb 100644 --- a/test/f32-gemminc-minmax-2.cc +++ b/test/f32-gemminc-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -186,14 +181,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,14 +206,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + 
"_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -326,6 +305,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -334,7 +314,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -344,12 +324,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -457,14 +431,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -490,14 +456,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -604,6 +562,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -623,6 +582,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -642,6 +602,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/12, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x12__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -661,6 +622,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -680,6 +642,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -699,6 +662,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -718,6 +682,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -737,6 +702,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/12, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x12__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -756,6 +722,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -775,6 +742,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -794,6 +762,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -813,6 +782,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -835,6 +805,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -857,6 +828,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -876,6 +848,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -895,6 +868,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -917,6 +891,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -939,6 +914,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -958,6 +934,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__neon_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -977,6 +954,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -996,6 +974,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1018,6 +997,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1040,6 +1020,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1062,6 +1043,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1084,6 +1066,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1103,6 +1086,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1122,6 +1106,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1141,6 +1126,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1160,6 +1146,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_gemminc_minmax_ukernel_8x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1179,6 +1166,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_8x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1201,6 +1189,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1220,6 +1209,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1239,6 +1229,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1258,6 +1249,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1277,6 +1269,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1296,6 +1289,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1315,6 +1309,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1334,6 +1329,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1353,6 +1349,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1372,6 +1369,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1391,6 +1389,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_7x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1410,6 +1409,7 @@ std::vector 
CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1429,6 +1429,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1448,6 +1449,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1467,6 +1469,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1486,6 +1489,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1508,6 +1512,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1527,6 +1532,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1546,6 +1552,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_7x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1565,6 +1572,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_8x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1587,6 +1595,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1603,6 +1612,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -1619,6 +1629,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -1635,6 +1646,7 @@ std::vector CreateTests2( 
/*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1651,6 +1663,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -1667,6 +1680,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -1683,6 +1697,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1699,6 +1714,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1715,6 +1731,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1731,6 +1748,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -1747,6 +1765,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -1763,6 +1782,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -1779,6 +1799,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -1798,6 +1819,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1814,6 +1836,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1830,6 +1853,7 @@ std::vector CreateTests2( 
/*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -1846,6 +1870,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -1862,6 +1887,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1878,6 +1904,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -1894,6 +1921,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -1910,6 +1938,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1926,6 +1955,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1945,6 +1975,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_2x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1961,6 +1992,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1979,6 +2011,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-gemminc-minmax.cc b/test/f32-gemminc-minmax.cc index 70efd63aec47..907e79dbaf77 100644 --- a/test/f32-gemminc-minmax.cc +++ b/test/f32-gemminc-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + 
.mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -186,14 +181,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,14 +206,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -326,6 +305,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -334,7 +314,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -344,12 +324,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -457,14 +431,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -490,14 +456,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -604,6 +562,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -623,6 +582,7 
@@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -642,6 +602,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -661,6 +622,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -680,6 +642,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -699,6 +662,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -718,6 +682,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73, xnn_init_f32_minmax_scalar_params, @@ -737,6 +702,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -759,6 +725,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -778,6 +745,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -800,6 +768,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -822,6 +791,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -841,6 +811,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__neonfma_dup_ld128, 
xnn_init_f32_minmax_scalar_params, @@ -860,6 +831,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -879,6 +851,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -901,6 +874,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -923,6 +897,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -942,6 +917,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__neon_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -961,6 +937,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -980,6 +957,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__neonfma_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1002,6 +980,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1021,6 +1000,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1040,6 +1020,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1059,6 +1040,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1078,6 +1060,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1097,6 +1080,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, 
/*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1116,6 +1100,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1135,6 +1120,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1154,6 +1140,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1173,6 +1160,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1192,6 +1180,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1211,6 +1200,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1230,6 +1220,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1249,6 +1240,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1268,6 +1260,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1287,6 +1280,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1306,6 +1300,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1325,6 +1320,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_7x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1344,6 +1340,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_8x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1366,6 +1363,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1385,6 +1383,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1407,6 +1406,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1423,6 +1423,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -1439,6 +1440,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -1455,6 +1457,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -1471,6 +1474,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1487,6 +1491,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -1503,6 +1508,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -1519,6 +1525,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1535,6 +1542,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -1551,6 +1559,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -1567,6 +1576,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -1583,6 +1593,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1599,6 +1610,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -1615,6 +1627,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -1631,6 +1644,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -1647,6 +1661,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -1663,6 +1678,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1682,6 +1698,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1698,6 +1715,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -1714,6 +1732,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -1730,6 +1749,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -1746,6 +1766,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -1762,6 +1783,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1778,6 +1800,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -1794,6 +1817,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_3x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -1810,6 +1834,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1826,6 +1851,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -1842,6 +1868,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -1858,6 +1885,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1874,6 +1902,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -1890,6 +1919,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -1906,6 +1936,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -1922,6 +1953,7 @@ std::vector CreateTests2( 
/*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -1938,6 +1970,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -1954,6 +1987,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -1970,6 +2004,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -1986,6 +2021,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2002,6 +2038,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2021,6 +2058,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -2039,6 +2077,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_1x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -2056,6 +2095,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_gemminc_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-igemm-2.cc b/test/f32-igemm-2.cc index 3e8112cdae53..9b8905795872 100644 --- a/test/f32-igemm-2.cc +++ b/test/f32-igemm-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , 
test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_5x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -642,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_6x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_6x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -675,6 +640,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -690,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -782,6 +754,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_2x4__scalar, xnn_pack_f32_conv_goki_w); diff --git a/test/f32-igemm-minmax-2.cc b/test/f32-igemm-minmax-2.cc index 
430192b45ef0..0920ff0d933a 100644 --- a/test/f32-igemm-minmax-2.cc +++ b/test/f32-igemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -306,6 +285,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -314,7 +294,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -324,12 +304,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -457,14 +431,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -490,14 +456,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - 
.cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -598,6 +556,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -607,7 +566,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -617,12 +576,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -730,14 +683,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -763,14 +708,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -878,6 +815,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -897,6 +835,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a7, xnn_init_f32_minmax_scalar_params, @@ -916,6 +855,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -935,6 +875,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_ld64, xnn_init_f32_minmax_scalar_params, @@ -957,6 +898,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm, 
xnn_init_f32_minmax_scalar_params, @@ -976,6 +918,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -995,6 +938,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1014,6 +958,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1033,6 +978,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_prfm, xnn_init_f32_minmax_scalar_params, @@ -1052,6 +998,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -1071,6 +1018,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -1090,6 +1038,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1109,6 +1058,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1128,6 +1078,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1147,6 +1098,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1166,6 +1118,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1185,6 +1138,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73, xnn_init_f32_minmax_scalar_params, @@ -1204,6 +1158,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -1226,6 +1181,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1245,6 +1201,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1267,6 +1224,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1286,6 +1244,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1305,6 +1264,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1327,6 +1287,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1346,6 +1307,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1368,6 +1330,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1387,6 +1350,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1406,6 +1370,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1428,6 +1393,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1447,6 +1413,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1469,6 +1436,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1488,6 +1456,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x4__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1507,6 +1476,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1529,6 +1499,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1548,6 +1519,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__neon_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1567,6 +1539,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1586,6 +1559,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1605,6 +1579,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__neonfma_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1624,6 +1599,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1643,6 +1619,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1665,6 +1642,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_igemm_minmax_ukernel_6x2__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1687,6 +1665,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x2__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1709,6 +1688,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1731,6 +1711,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__neonfma_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1753,6 +1734,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1772,6 +1754,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1791,6 +1774,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1810,6 +1794,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2c4__sse, xnn_init_f32_minmax_scalar_params, @@ -1829,6 +1814,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1848,6 +1834,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1867,6 +1854,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1886,6 +1874,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1905,6 +1894,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1924,6 +1914,7 @@ std::vector 
CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1943,6 +1934,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1962,6 +1954,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_7x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1981,6 +1974,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2000,6 +1994,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2019,6 +2014,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2038,6 +2034,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2057,6 +2054,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2076,6 +2074,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2095,6 +2094,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2114,6 +2114,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2133,6 +2134,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2152,6 +2154,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/4, 
/*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x16s4__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2174,6 +2177,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2193,6 +2197,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2212,6 +2217,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_7x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2231,6 +2237,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_8x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2250,6 +2257,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2269,6 +2277,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2288,6 +2297,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_7x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2307,6 +2317,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_8x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2329,6 +2340,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2345,6 +2357,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2361,6 +2374,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2377,6 +2391,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2393,6 +2408,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2409,6 +2425,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2425,6 +2442,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2c4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2441,6 +2459,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2c4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2457,6 +2476,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2473,6 +2493,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2489,6 +2510,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2505,6 +2527,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2521,6 +2544,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2537,6 +2561,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2553,6 +2578,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2572,6 +2598,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2588,6 +2615,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -2604,6 +2632,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2620,6 +2649,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2636,6 +2666,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2652,6 +2683,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2668,6 +2700,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2c4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2684,6 +2717,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2700,6 +2734,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2716,6 +2751,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2732,6 +2768,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2748,6 +2785,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -2764,6 +2802,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_igemm_minmax_ukernel_5x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2780,6 +2819,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2796,6 +2836,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2812,6 +2853,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -2828,6 +2870,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -2844,6 +2887,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2863,6 +2907,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_2x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -2879,6 +2924,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -2895,6 +2941,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -2913,6 +2960,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -2930,6 +2978,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -2947,6 +2996,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -2965,6 +3015,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -2984,6 +3035,7 @@ 
INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_7x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -3006,6 +3058,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/128, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x128__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3025,6 +3078,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/128, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_2x128__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3044,6 +3098,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_7x64__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3063,6 +3118,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_8x32__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -3082,6 +3138,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/16, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_16x32__hvx_broadcast, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-igemm-minmax.cc b/test/f32-igemm-minmax.cc index dc1352510b93..33f8bb2688ab 100644 --- a/test/f32-igemm-minmax.cc +++ b/test/f32-igemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - 
.loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -306,6 +285,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -314,7 +294,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -324,12 +304,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -457,14 +431,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -490,14 +456,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -598,6 +556,7 @@ std::vector CreateTests2( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -607,7 +566,7 @@ std::vector CreateTests2( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -617,12 +576,6 @@ std::vector CreateTests2( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -730,14 +683,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -763,14 +708,6 @@ std::vector CreateTests2( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - 
gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -878,6 +815,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -897,6 +835,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -916,6 +855,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -935,6 +875,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -954,6 +895,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -976,6 +918,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -995,6 +938,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/12, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x12__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1014,6 +958,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1033,6 +978,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1052,6 +998,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1071,6 +1018,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1090,6 +1038,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1109,6 +1058,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/12, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x12__asm_aarch64_neonfma_cortex_a53, xnn_init_f32_minmax_scalar_params, @@ -1128,6 +1078,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1147,6 +1098,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm, xnn_init_f32_minmax_scalar_params, @@ -1166,6 +1118,7 @@ std::vector CreateTests2( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -1185,6 +1138,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -1204,6 +1158,7 @@ std::vector CreateTests2( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -1223,6 +1178,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -1245,6 +1201,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_2x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1264,6 +1221,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1283,6 +1241,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1302,6 +1261,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x16__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1324,6 +1284,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_2x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1343,6 +1304,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1362,6 +1324,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1384,6 +1347,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__neon_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1406,6 +1370,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1428,6 +1393,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1447,6 +1413,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1466,6 +1433,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1485,6 +1453,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x2__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1504,6 +1473,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x4__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1523,6 +1493,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1542,6 +1513,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__neon_dup_ld128, xnn_init_f32_minmax_scalar_params, @@ -1561,6 +1533,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1580,6 +1553,7 @@ std::vector CreateTests2( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1599,6 +1573,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1618,6 +1593,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1637,6 +1613,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_8x8s4__neon, xnn_init_f32_minmax_scalar_params, @@ -1656,6 +1633,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_8x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1678,6 +1656,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1697,6 +1676,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1716,6 +1696,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1735,6 +1716,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1754,6 +1736,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1773,6 +1756,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__sse_load1, xnn_init_f32_minmax_scalar_params, @@ -1792,6 +1776,7 @@ 
std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x2c4__sse, xnn_init_f32_minmax_scalar_params, @@ -1811,6 +1796,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__sse_dup, xnn_init_f32_minmax_scalar_params, @@ -1830,6 +1816,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8s4__sse, xnn_init_f32_minmax_scalar_params, @@ -1849,6 +1836,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1868,6 +1856,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1887,6 +1876,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1906,6 +1896,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1925,6 +1916,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1944,6 +1936,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1963,6 +1956,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1982,6 +1976,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2001,6 +1996,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2020,6 +2016,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2039,6 +2036,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_7x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2058,6 +2056,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_8x8__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2080,6 +2079,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2099,6 +2099,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x16__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2118,6 +2119,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2137,6 +2139,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x32__avx512f_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2159,6 +2162,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2175,6 +2179,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2191,6 +2196,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2207,6 +2213,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2223,6 +2230,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2239,6 +2247,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2255,6 +2264,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2271,6 +2281,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2287,6 +2298,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2303,6 +2315,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2319,6 +2332,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2335,6 +2349,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2351,6 +2366,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2367,6 +2383,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2383,6 +2400,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2399,6 +2417,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2415,6 +2434,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2434,6 +2454,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2450,6 +2471,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2466,6 +2488,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2482,6 +2505,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2498,6 +2522,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -2514,6 +2539,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2530,6 +2556,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2546,6 +2573,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2562,6 +2590,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -2578,6 +2607,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -2594,6 +2624,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2610,6 +2641,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2626,6 +2658,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2642,6 +2675,7 @@ std::vector CreateTests2( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -2661,6 +2695,7 @@ std::vector CreateTests2( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -2679,6 +2714,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -2697,6 +2733,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x32__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2716,6 +2753,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_1x64__hvx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2735,6 +2773,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/64, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_minmax_ukernel_4x64__hvx_broadcast, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-igemm-relu-2.cc b/test/f32-igemm-relu-2.cc index fb61a2c79a42..7ed57e31f730 100644 --- a/test/f32-igemm-relu-2.cc +++ b/test/f32-igemm-relu-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - 
gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_3x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_3x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x2c4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -642,6 
+605,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_5x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_6x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -672,6 +637,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_6x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -690,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -780,6 +752,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_conv_goki_w); @@ -798,6 +771,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_2x4__wasm, xnn_pack_f32_conv_goki_w); @@ -813,6 +787,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x2__wasm, xnn_pack_f32_conv_goki_w); @@ -828,6 +803,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x4__wasm, xnn_pack_f32_conv_goki_w); @@ -845,6 +821,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x4__scalar, xnn_pack_f32_conv_goki_w); @@ -861,6 +838,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x2__scalar, xnn_pack_f32_conv_goki_w); @@ -877,6 +855,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x4__scalar, xnn_pack_f32_conv_goki_w); @@ -894,6 +873,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x4v__rvv, xnn_pack_f32_conv_goki_w); @@ -912,6 +892,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_7x4v__rvv, xnn_pack_f32_conv_goki_w); diff --git a/test/f32-igemm-relu.cc b/test/f32-igemm-relu.cc index 39050788e33a..0212b775edc4 100644 --- a/test/f32-igemm-relu.cc +++ b/test/f32-igemm-relu.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool 
is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_3x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -642,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_5x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_5x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -672,6 +637,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_6x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -690,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -780,6 +752,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -795,6 +768,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -810,6 +784,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -828,6 +803,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_1x4__wasm, xnn_pack_f32_conv_goki_w); @@ -845,6 +821,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_relu_ukernel_2x4__scalar, xnn_pack_f32_conv_goki_w); diff --git a/test/f32-igemm.cc b/test/f32-igemm.cc index 
276a544d92f0..60f391033080 100644 --- a/test/f32-igemm.cc +++ b/test/f32-igemm.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -297,6 +276,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -306,7 +286,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -316,12 +296,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -429,14 +403,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -462,14 +428,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - 
.cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -567,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_3x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -582,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_3x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -597,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_3x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -612,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x2c4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -627,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -642,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -657,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_5x8__wasmsimd_loadsplat, xnn_pack_f32_conv_goki_w); @@ -672,6 +637,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_5x8s4__wasmsimd, xnn_pack_f32_conv_goki_w); @@ -687,6 +653,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_6x8__wasmsimd_splat, xnn_pack_f32_conv_goki_w); @@ -705,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -720,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -735,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -750,6 +720,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -765,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -780,6 +752,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -795,6 +768,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -810,6 +784,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_conv_goki_w); @@ -825,6 +800,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_conv_goki_w); @@ -842,6 +818,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x4__scalar, xnn_pack_f32_conv_goki_w); @@ -858,6 +835,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x2__scalar, xnn_pack_f32_conv_goki_w); @@ -874,6 +852,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_4x4__scalar, xnn_pack_f32_conv_goki_w); @@ -891,6 +870,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_1x4v__rvv, xnn_pack_f32_conv_goki_w); @@ -909,6 +889,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_igemm_ukernel_7x4v__rvv, xnn_pack_f32_conv_goki_w); diff --git a/test/f32-ppmm-minmax.cc b/test/f32-ppmm-minmax.cc index 4b67e49bc9b9..4c78d0a79147 100644 --- a/test/f32-ppmm-minmax.cc +++ b/test/f32-ppmm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + 
.mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128_prfm, xnn_init_f32_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__asm_aarch64_neonfma_cortex_a75, xnn_init_f32_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__asm_aarch64_neonfma_cortex_a75_prfm, xnn_init_f32_minmax_scalar_params, @@ -427,6 +412,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__asm_aarch64_neonfma_ld128_prfm, xnn_init_f32_minmax_scalar_params, @@ -468,6 +455,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__aarch64_neonfma, xnn_init_f32_minmax_scalar_params, @@ -487,6 +475,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__aarch64_neonfma_prfm, xnn_init_f32_minmax_scalar_params, @@ -509,6 +498,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__neon, xnn_init_f32_minmax_scalar_params, @@ -528,6 +518,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__neon_prfm, xnn_init_f32_minmax_scalar_params, @@ -550,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x16__aarch64_neonfma, xnn_init_f32_minmax_scalar_params, @@ -569,6 +561,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x16__aarch64_neonfma_prfm, xnn_init_f32_minmax_scalar_params, @@ -591,6 +584,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x16__neon, xnn_init_f32_minmax_scalar_params, @@ -610,6 +604,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x16__neon_prfm, xnn_init_f32_minmax_scalar_params, @@ -632,6 +627,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__aarch64_neonfma, xnn_init_f32_minmax_scalar_params, @@ -651,6 +647,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__aarch64_neonfma_prfm, xnn_init_f32_minmax_scalar_params, @@ -673,6 +670,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__neon, xnn_init_f32_minmax_scalar_params, @@ -692,6 +690,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
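// The `.unsigned_inputs(unsigned_inputs)` call added to the tester chain in this
// file relies on a setter in GemmMicrokernelTester that this patch does not touch.
// A minimal sketch of the assumed shape (the member `unsigned_inputs_` and the
// getter below are illustrative, not the actual implementation):
class GemmMicrokernelTester {
 public:
  // Fluent setter so it chains with mr()/nr()/kr()/sr(), as in the hunks above.
  GemmMicrokernelTester& unsigned_inputs(bool unsigned_inputs) {
    unsigned_inputs_ = unsigned_inputs;
    return *this;
  }
  bool unsigned_inputs() const { return unsigned_inputs_; }

 private:
  // Presumably consumed by Test() when generating and packing the input data.
  bool unsigned_inputs_ = false;
};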
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_8x8__neon_prfm, xnn_init_f32_minmax_scalar_params, @@ -714,6 +713,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__sse, xnn_init_f32_minmax_scalar_params, @@ -736,6 +736,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -752,6 +753,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -770,6 +772,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -787,6 +790,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/3, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_3x3__scalar, xnn_init_f32_minmax_scalar_params, @@ -804,6 +808,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -821,6 +826,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_ppmm_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-qc4w-gemm-minmax.cc b/test/f32-qc4w-gemm-minmax.cc index b76b7d724d32..4fe50f01d327 100644 --- a/test/f32-qc4w-gemm-minmax.cc +++ b/test/f32-qc4w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, 
nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -346,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -365,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -384,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc2_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -403,6 +382,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc4, xnn_init_f32_qc4w_minmax_scalar_params, @@ -422,6 +402,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc4_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -441,6 +422,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -460,6 +442,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -479,6 +462,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -501,6 +485,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -520,6 +505,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -542,6 +528,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__neon_dup_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -561,6 +548,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__neon_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -580,6 +568,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__neonfma_dup_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -602,6 +591,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -621,6 +611,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -643,6 +634,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__neon_dup_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -662,6 +654,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__neon_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -681,6 +674,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__neonfma_dup_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -703,6 +697,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x8__aarch64_neonfma_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -725,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x8__neon_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -747,6 +743,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -766,6 +763,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, 
/*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -788,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__neon_dup_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -807,6 +806,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__neon_lane_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -826,6 +826,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__neonfma_dup_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -848,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x8__sse41_dup, xnn_init_f32_qc4w_minmax_scalar_params, @@ -867,6 +869,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_3x8__sse41_dup, xnn_init_f32_qc4w_minmax_scalar_params, @@ -886,6 +889,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x8__sse41_dup, xnn_init_f32_qc4w_minmax_scalar_params, @@ -905,6 +909,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x8__sse41_dup, xnn_init_f32_qc4w_minmax_scalar_params, @@ -924,6 +929,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x8__sse41_dup, xnn_init_f32_qc4w_minmax_scalar_params, @@ -943,6 +949,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -962,6 +969,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_2x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -981,6 +989,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_3x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1000,6 +1009,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, 
/*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1019,6 +1029,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1038,6 +1049,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1057,6 +1069,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_7x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1076,6 +1089,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_8x16__avx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1095,6 +1109,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1114,6 +1129,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_2x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1133,6 +1149,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_3x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1152,6 +1169,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1171,6 +1189,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1190,6 +1209,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1209,6 +1229,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_7x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1228,6 +1249,7 @@ std::vector 
CreateTests1( /*adj_k_block=*/2, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_8x16__fma3_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1247,6 +1269,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1266,6 +1289,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_2x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1285,6 +1309,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_3x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1304,6 +1329,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1323,6 +1349,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1342,6 +1369,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1361,6 +1389,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_7x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1380,6 +1409,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_8x16__avx2_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1402,6 +1432,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1421,6 +1452,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_2x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1440,6 +1472,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/3, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_3x32__avx512skx_broadcast, 
xnn_init_f32_qc4w_minmax_scalar_params, @@ -1459,6 +1492,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1478,6 +1512,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_5x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1497,6 +1532,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_6x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1516,6 +1552,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_7x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1535,6 +1572,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_8x32__avx512skx_broadcast, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1557,6 +1595,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x4__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1573,6 +1612,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_2x4__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1589,6 +1629,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x2__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1605,6 +1646,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x4__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1623,6 +1665,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_1x4__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1640,6 +1683,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_2x4__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1657,6 +1701,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x2__scalar, 
xnn_init_f32_qc4w_minmax_scalar_params, @@ -1674,6 +1719,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc4w_gemm_minmax_ukernel_4x4__scalar, xnn_init_f32_qc4w_minmax_scalar_params, diff --git a/test/f32-qc8w-gemm-minmax.cc b/test/f32-qc8w-gemm-minmax.cc index 1901f435494d..04bd46ad903b 100644 --- a/test/f32-qc8w-gemm-minmax.cc +++ b/test/f32-qc8w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neon_ld128_acc2, xnn_init_f32_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neon_ld128_acc2_prfm, xnn_init_f32_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc2, 
xnn_init_f32_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc2_prfm, xnn_init_f32_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc4, xnn_init_f32_minmax_scalar_params, @@ -427,6 +412,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_acc4_prfm, xnn_init_f32_minmax_scalar_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64_prfm, xnn_init_f32_minmax_scalar_params, @@ -465,6 +452,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -484,6 +472,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2, xnn_init_f32_minmax_scalar_params, @@ -503,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2_prfm, xnn_init_f32_minmax_scalar_params, @@ -522,6 +512,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4, xnn_init_f32_minmax_scalar_params, @@ -541,6 +532,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4_prfm, xnn_init_f32_minmax_scalar_params, @@ -560,6 +552,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_prfm, xnn_init_f32_minmax_scalar_params, @@ -579,6 +572,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/1, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x1__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -598,6 +592,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/1, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x1__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -617,6 +612,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -636,6 +632,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -655,6 +652,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -674,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -693,6 +692,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld64, xnn_init_f32_minmax_scalar_params, @@ -712,6 +712,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128, xnn_init_f32_minmax_scalar_params, @@ -734,6 +735,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -753,6 +755,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -775,6 +778,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -794,6 +798,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -813,6 +818,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -835,6 +841,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, 
/*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -854,6 +861,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -876,6 +884,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -898,6 +907,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -917,6 +927,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -939,6 +950,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -958,6 +970,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -977,6 +990,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -999,6 +1013,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x16__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1018,6 +1033,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1040,6 +1056,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1062,6 +1079,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x2__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1084,6 +1102,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/2, /*mr=*/6, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x2__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1106,6 +1125,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1125,6 +1145,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128, xnn_init_f32_minmax_scalar_params, @@ -1147,6 +1168,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__neon_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1166,6 +1188,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__neon_lane_ld64, xnn_init_f32_minmax_scalar_params, @@ -1185,6 +1208,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__neonfma_dup_ld64, xnn_init_f32_minmax_scalar_params, @@ -1204,6 +1228,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1223,6 +1248,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1242,6 +1268,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8s4__neonfma, xnn_init_f32_minmax_scalar_params, @@ -1264,6 +1291,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__sse41_dup, xnn_init_f32_minmax_scalar_params, @@ -1283,6 +1311,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__sse41_load1, xnn_init_f32_minmax_scalar_params, @@ -1302,6 +1331,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8s4__sse41, xnn_init_f32_minmax_scalar_params, @@ -1321,6 +1351,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, 
/*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__sse41_dup, xnn_init_f32_minmax_scalar_params, @@ -1340,6 +1371,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__sse41_load1, xnn_init_f32_minmax_scalar_params, @@ -1359,6 +1391,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8s4__sse41, xnn_init_f32_minmax_scalar_params, @@ -1378,6 +1411,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2c4__sse41, xnn_init_f32_minmax_scalar_params, @@ -1397,6 +1431,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__sse41_dup, xnn_init_f32_minmax_scalar_params, @@ -1416,6 +1451,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__sse41_load1, xnn_init_f32_minmax_scalar_params, @@ -1435,6 +1471,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8s4__sse41, xnn_init_f32_minmax_scalar_params, @@ -1454,6 +1491,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__sse41_dup, xnn_init_f32_minmax_scalar_params, @@ -1473,6 +1511,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__sse41_load1, xnn_init_f32_minmax_scalar_params, @@ -1492,6 +1531,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8s4__sse41, xnn_init_f32_minmax_scalar_params, @@ -1511,6 +1551,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__sse41_dup, xnn_init_f32_minmax_scalar_params, @@ -1530,6 +1571,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__sse41_load1, xnn_init_f32_minmax_scalar_params, @@ -1549,6 +1591,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8s4__sse41, xnn_init_f32_minmax_scalar_params, @@ -1568,6 +1611,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1587,6 +1631,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1606,6 +1651,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1625,6 +1671,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1644,6 +1691,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1663,6 +1711,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1682,6 +1731,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_7x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1701,6 +1751,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_8x16__avx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1720,6 +1771,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1739,6 +1791,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1758,6 +1811,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1777,6 +1831,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1796,6 +1851,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1815,6 +1871,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1834,6 +1891,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_7x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1853,6 +1911,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_8x16__fma3_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1872,6 +1931,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1891,6 +1951,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x16s4__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1910,6 +1971,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x16s4__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1929,6 +1991,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x16s4__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1948,6 +2011,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1967,6 +2031,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x16s4__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -1986,6 +2051,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2005,6 +2071,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x16s4__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2024,6 +2091,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2043,6 +2111,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x16s4__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2062,6 +2131,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_7x8__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2081,6 +2151,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_8x8__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2100,6 +2171,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2119,6 +2191,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2138,6 +2211,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2157,6 +2231,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2176,6 +2251,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2195,6 +2271,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2214,6 +2291,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_7x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2233,6 +2311,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_8x16__avx2_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2255,6 +2334,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2274,6 +2354,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2293,6 +2374,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2312,6 +2394,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2331,6 +2414,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2350,6 +2434,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2369,6 +2454,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2388,6 +2474,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2407,6 +2494,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2426,6 +2514,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2445,6 +2534,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2464,6 +2554,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2483,6 +2574,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_7x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2502,6 +2594,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_7x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2521,6 +2614,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_8x16__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2540,6 +2634,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_8x32__avx512skx_broadcast, xnn_init_f32_minmax_scalar_params, @@ -2562,6 +2657,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2578,6 +2674,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2594,6 +2691,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2610,6 +2708,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2626,6 +2725,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2642,6 +2742,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2658,6 +2759,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2674,6 +2776,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2690,6 +2793,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2706,6 +2810,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2722,6 +2827,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2738,6 +2844,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2754,6 +2861,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2c4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2770,6 +2878,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2c4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2786,6 +2895,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2802,6 +2912,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2818,6 +2929,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2834,6 +2946,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2850,6 +2963,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2866,6 +2980,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, 
/*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2882,6 +2997,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2898,6 +3014,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -2914,6 +3031,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2930,6 +3048,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -2946,6 +3065,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -2962,6 +3082,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -2978,6 +3099,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmsimd_arm_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -2994,6 +3116,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmsimd_arm_splat, xnn_init_f32_minmax_scalar_params, @@ -3010,6 +3133,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmsimd_x86_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3026,6 +3150,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmsimd_x86_splat, xnn_init_f32_minmax_scalar_params, @@ -3042,6 +3167,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8s4__wasmsimd_arm, xnn_init_f32_minmax_scalar_params, @@ -3058,6 +3184,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, 
/*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8s4__wasmsimd_x86, xnn_init_f32_minmax_scalar_params, @@ -3077,6 +3204,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3093,6 +3221,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3109,6 +3238,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3125,6 +3255,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3141,6 +3272,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3157,6 +3289,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3173,6 +3306,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3189,6 +3323,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3205,6 +3340,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3221,6 +3357,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3237,6 +3374,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3253,6 
+3391,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3269,6 +3408,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2c4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3285,6 +3425,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3301,6 +3442,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3317,6 +3459,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3333,6 +3476,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3349,6 +3493,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3365,6 +3510,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3381,6 +3527,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3397,6 +3544,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3413,6 +3561,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3429,6 +3578,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3445,6 +3595,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3461,6 +3612,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3477,6 +3629,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3493,6 +3646,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3509,6 +3663,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_init_f32_minmax_scalar_params, @@ -3525,6 +3680,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_loadsplat, xnn_init_f32_minmax_scalar_params, @@ -3541,6 +3697,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_splat, xnn_init_f32_minmax_scalar_params, @@ -3557,6 +3714,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8s4__wasmrelaxedsimd, xnn_init_f32_minmax_scalar_params, @@ -3573,6 +3731,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_init_f32_minmax_scalar_params, @@ -3592,6 +3751,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -3608,6 +3768,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -3624,6 +3785,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -3640,6 +3802,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -3658,6 +3821,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_1x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -3675,6 +3839,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -3692,6 +3857,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -3709,6 +3875,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/f32-qc8w-gemm-relu.cc b/test/f32-qc8w-gemm-relu.cc index 34b2b3dabbcc..91ab25600535 100644 --- a/test/f32-qc8w-gemm-relu.cc +++ b/test/f32-qc8w-gemm-relu.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + 
"_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -318,6 +298,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -333,6 +314,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -348,6 +330,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -363,6 +346,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -378,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -393,6 +378,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -408,6 +394,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -423,6 +410,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -438,6 +426,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -453,6 +442,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -468,6 +458,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_5x8__wasmrelaxedsimd_fma_splat, 
xnn_pack_f32_qs8w_gemm_goi_w); @@ -483,6 +474,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -498,6 +490,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -513,6 +506,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -528,6 +522,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -546,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_1x4__wasm, xnn_pack_f32_qs8w_gemm_goi_w); @@ -561,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_2x4__wasm, xnn_pack_f32_qs8w_gemm_goi_w); @@ -576,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x2__wasm, xnn_pack_f32_qs8w_gemm_goi_w); @@ -591,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x4__wasm, xnn_pack_f32_qs8w_gemm_goi_w); @@ -608,6 +607,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_1x4__scalar, xnn_pack_f32_qs8w_gemm_goi_w); @@ -624,6 +624,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_2x4__scalar, xnn_pack_f32_qs8w_gemm_goi_w); @@ -640,6 +641,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x2__scalar, xnn_pack_f32_qs8w_gemm_goi_w); @@ -656,6 +658,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_relu_ukernel_4x4__scalar, xnn_pack_f32_qs8w_gemm_goi_w); diff --git a/test/f32-qc8w-gemm.cc b/test/f32-qc8w-gemm.cc index 00262887e3f2..7558f1f6cef2 100644 --- a/test/f32-qc8w-gemm.cc +++ 
b/test/f32-qc8w-gemm.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x8__wasmsimd_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -318,6 +298,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x8__wasmsimd_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -333,6 +314,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x8s4__wasmsimd, xnn_pack_f32_qs8w_gemm_goi_w); @@ -348,6 +330,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_3x8__wasmsimd_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -363,6 +346,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_3x8__wasmsimd_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -378,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_3x8s4__wasmsimd, xnn_pack_f32_qs8w_gemm_goi_w); @@ -393,6 
+378,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x2c4__wasmsimd, xnn_pack_f32_qs8w_gemm_goi_w); @@ -408,6 +394,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x8__wasmsimd_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -423,6 +410,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x8__wasmsimd_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -438,6 +426,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x8s4__wasmsimd, xnn_pack_f32_qs8w_gemm_goi_w); @@ -453,6 +442,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_5x8__wasmsimd_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -468,6 +458,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_5x8__wasmsimd_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -483,6 +474,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_5x8s4__wasmsimd, xnn_pack_f32_qs8w_gemm_goi_w); @@ -498,6 +490,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_6x8__wasmsimd_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -513,6 +506,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_6x8__wasmsimd_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -528,6 +522,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_6x8s4__wasmsimd, xnn_pack_f32_qs8w_gemm_goi_w); @@ -546,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -561,6 +557,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -576,6 +573,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -591,6 +589,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_3x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -606,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_3x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -621,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_3x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -636,6 +637,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x2c4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -651,6 +653,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -666,6 +669,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -681,6 +685,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -696,6 +701,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_5x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -711,6 +717,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_5x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -726,6 +733,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_5x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -741,6 +749,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_6x8__wasmrelaxedsimd_fma_loadsplat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -756,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_f32_qc8w_gemm_ukernel_6x8__wasmrelaxedsimd_fma_splat, xnn_pack_f32_qs8w_gemm_goi_w); @@ -771,6 +781,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_6x8s4__wasmrelaxedsimd_fma, xnn_pack_f32_qs8w_gemm_goi_w); @@ -788,6 +799,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_1x4__scalar, xnn_pack_f32_qs8w_gemm_goi_w); @@ -804,6 +816,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_2x4__scalar, xnn_pack_f32_qs8w_gemm_goi_w); @@ -820,6 +833,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x2__scalar, xnn_pack_f32_qs8w_gemm_goi_w); @@ -836,6 +850,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_f32_qc8w_gemm_ukernel_4x4__scalar, xnn_pack_f32_qs8w_gemm_goi_w); diff --git a/test/f32-raddextexp.cc b/test/f32-raddextexp.cc index afd9bc0a0215..4ee083e0eece 100644 --- a/test/f32-raddextexp.cc +++ b/test/f32-raddextexp.cc @@ -2,902 +2,116 @@ // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -// -// Auto-generated file. Do not edit! 
-// Specification: test/f32-raddextexp.yaml -// Generator: tools/generate-raddextexp-test.py +#include +#include +#include +#include +#include +#include +#include +#include + #include + +#include "replicable_random_device.h" +#include "xnnpack.h" +#include "xnnpack/buffer.h" #include "xnnpack/common.h" #include "xnnpack/isa-checks.h" +#include "xnnpack/microfnptr.h" #include "xnnpack/raddextexp.h" -#include "raddextexp-microkernel-tester.h" - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U64, elements_eq_64) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(64) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64, elements_div_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 128; elements < 640; elements += 64) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64, elements_lt_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 64; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64, elements_gt_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 65; elements < 128; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC2, elements_eq_64) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(64) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC2, elements_div_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 128; elements < 640; elements += 64) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC2, elements_lt_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 64; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC2, elements_gt_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 65; elements < 128; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC4, elements_eq_64) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(64) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC4, elements_div_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 128; elements < 640; elements += 64) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC4, elements_lt_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 64; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U64_ACC4, elements_gt_64) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 65; elements < 128; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - 
.Test(xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U72, elements_eq_72) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(72) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U72, elements_div_72) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 144; elements < 720; elements += 72) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U72, elements_lt_72) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 72; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U72, elements_gt_72) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 73; elements < 144; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U72_ACC3, elements_eq_72) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(72) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U72_ACC3, elements_div_72) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 144; elements < 720; elements += 72) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U72_ACC3, elements_lt_72) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 72; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U72_ACC3, elements_gt_72) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 73; elements < 144; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U80, elements_eq_80) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(80) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80, elements_div_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 160; elements < 800; elements += 80) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80, elements_lt_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 80; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80, elements_gt_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 81; elements < 160; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC2, elements_eq_80) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(80) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC2, elements_div_80) { - TEST_REQUIRES_X86_AVX2; - 
for (size_t elements = 160; elements < 800; elements += 80) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC2, elements_lt_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 80; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC2, elements_gt_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 81; elements < 160; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC5, elements_eq_80) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(80) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC5, elements_div_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 160; elements < 800; elements += 80) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC5, elements_lt_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 80; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U80_ACC5, elements_gt_80) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 81; elements < 160; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U96, elements_eq_96) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(96) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96, elements_div_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 192; elements < 960; elements += 96) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96, elements_lt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 96; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96, elements_gt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 97; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC2, elements_eq_96) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(96) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC2, elements_div_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 192; elements < 960; elements += 96) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC2, elements_lt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 96; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - 
.Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC2, elements_gt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 97; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC3, elements_eq_96) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(96) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC3, elements_div_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 192; elements < 960; elements += 96) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC3, elements_lt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 96; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC3, elements_gt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 97; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC6, elements_eq_96) { - TEST_REQUIRES_X86_AVX2; - RAddExtExpMicrokernelTester() - .elements(96) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6); - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC6, elements_div_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 192; elements < 960; elements += 96) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC6, elements_lt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 1; elements < 96; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6); - } - } - - TEST(F32_RADDEXTEXP__AVX2_P5_U96_ACC6, elements_gt_96) { - TEST_REQUIRES_X86_AVX2; - for (size_t elements = 97; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6); - } - } -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128, elements_eq_128) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(128) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128); - } - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128, elements_div_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 256; elements < 1280; elements += 128) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128, elements_lt_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 128; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128, elements_gt_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 129; elements < 256; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - 
.Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC2, elements_eq_128) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(128) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC2, elements_div_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 256; elements < 1280; elements += 128) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC2, elements_lt_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 128; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC2, elements_gt_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 129; elements < 256; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC4, elements_eq_128) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(128) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC4, elements_div_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 256; elements < 1280; elements += 128) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC4, elements_lt_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 128; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U128_ACC4, elements_gt_128) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 129; elements < 256; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144, elements_eq_144) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(144) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144, elements_div_144) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 288; elements < 1440; elements += 144) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144, elements_lt_144) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 144; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144, elements_gt_144) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 145; 
elements < 288; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144_ACC3, elements_eq_144) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(144) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144_ACC3, elements_div_144) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 288; elements < 1440; elements += 144) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144_ACC3, elements_lt_144) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 144; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U144_ACC3, elements_gt_144) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 145; elements < 288; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160, elements_eq_160) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(160) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160, elements_div_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 320; elements < 1600; elements += 160) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160, elements_lt_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 160; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160, elements_gt_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 161; elements < 320; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC2, elements_eq_160) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(160) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC2, elements_div_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 320; elements < 1600; elements += 160) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC2, elements_lt_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 160; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC2, 
elements_gt_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 161; elements < 320; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC5, elements_eq_160) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(160) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC5, elements_div_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 320; elements < 1600; elements += 160) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC5, elements_lt_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 160; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U160_ACC5, elements_gt_160) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 161; elements < 320; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192, elements_eq_192) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(192) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192, elements_div_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 384; elements < 1920; elements += 192) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192, elements_lt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192, elements_gt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 193; elements < 384; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC2, elements_eq_192) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(192) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC2, elements_div_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 384; elements < 1920; elements += 192) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC2, elements_lt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - 
.Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC2, elements_gt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 193; elements < 384; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC3, elements_eq_192) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(192) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC3, elements_div_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 384; elements < 1920; elements += 192) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC3, elements_lt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC3, elements_gt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 193; elements < 384; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - - -#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC6, elements_eq_192) { - TEST_REQUIRES_X86_AVX512F; - RAddExtExpMicrokernelTester() - .elements(192) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6); - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC6, elements_div_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 384; elements < 1920; elements += 192) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC6, elements_lt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 1; elements < 192; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6); - } - } - - TEST(F32_RADDEXTEXP__AVX512F_P5_SCALEF_U192_ACC6, elements_gt_192) { - TEST_REQUIRES_X86_AVX512F; - for (size_t elements = 193; elements < 384; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6); - } - } -#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +class RAddExtExpMicrokernelTester { + public: + RAddExtExpMicrokernelTester& elements(size_t elements) { + assert(elements != 0); + this->elements_ = elements; + return *this; + } + + size_t elements() const { + return this->elements_; + } + + RAddExtExpMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + size_t iterations() const { + return this->iterations_; + } + + void Test(xnn_f32_raddextexp_ukernel_fn raddextexp) const { + xnnpack::ReplicableRandomDevice rng; + // Choose such range that expf(x[i]) overflows, but double-precision exp doesn't overflow. 
+ auto f32rng = [&rng]() { + return std::uniform_real_distribution<float>(90.0f, 100.0f)(rng); + }; + + xnnpack::Buffer<float> x(elements() + XNN_EXTRA_BYTES / sizeof(float)); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(f32rng)); + + // Compute reference results. + double sum_ref = 0.0; + for (size_t i = 0; i < elements(); i++) { + sum_ref += exp(double(x[i])); + } + + // Call optimized micro-kernel. + float sum[2]; + raddextexp(elements() * sizeof(float), x.data(), sum); + + // Verify results. + ASSERT_NEAR(sum_ref, exp2(double(sum[1])) * double(sum[0]), std::abs(sum_ref) * 1.0e-6) + << "elements = " << elements() << ", y:value = " << sum[0] << ", y:exponent = " << sum[1]; + } + } + + private: + size_t elements_{1}; + size_t iterations_{15}; +}; + +#define XNN_TEST_RADDEXTEXP_ELEMENT_EQ(ukernel, arch_flags, element_tile, ...) \ + TEST(ukernel, element_eq) \ + { \ + TEST_REQUIRES_ARCH_FLAGS(arch_flags); \ + RAddExtExpMicrokernelTester().elements(element_tile).Test(ukernel); \ + } +#define XNN_TEST_RADDEXTEXP_ELEMENT_DIV(ukernel, arch_flags, element_tile, ...) \ + TEST(ukernel, element_div) \ + { \ + TEST_REQUIRES_ARCH_FLAGS(arch_flags); \ + for (size_t element_size = element_tile * 2; element_size < element_tile * 10; element_size += element_tile) { \ + RAddExtExpMicrokernelTester().elements(element_size).Test(ukernel); \ + } \ + } +#define XNN_TEST_RADDEXTEXP_ELEMENT_LT(ukernel, arch_flags, element_tile, ...) \ + TEST(ukernel, element_lt) \ + { \ + TEST_REQUIRES_ARCH_FLAGS(arch_flags); \ + for (size_t element_size = 1; element_size < element_tile; element_size++) { \ + RAddExtExpMicrokernelTester().elements(element_size).Test(ukernel); \ + } \ + } +#define XNN_TEST_RADDEXTEXP_ELEMENT_GT(ukernel, arch_flags, element_tile, ...) \ + TEST(ukernel, element_gt) \ + { \ + TEST_REQUIRES_ARCH_FLAGS(arch_flags); \ + for (size_t element_size = element_tile + 1; element_size < (element_tile == 1 ? 10 : element_tile * 2); \ + element_size++) { \ + RAddExtExpMicrokernelTester().elements(element_size).Test(ukernel); \ + } \ + } + +#define XNN_UKERNEL_WITH_PARAMS(arch_flags, ukernel, element_tile, datatype, params_type, init_params) \ + XNN_TEST_RADDEXTEXP_ELEMENT_EQ(ukernel, arch_flags, element_tile, init_params); \ + XNN_TEST_RADDEXTEXP_ELEMENT_DIV(ukernel, arch_flags, element_tile, init_params); \ + XNN_TEST_RADDEXTEXP_ELEMENT_LT(ukernel, arch_flags, element_tile, init_params); \ + XNN_TEST_RADDEXTEXP_ELEMENT_GT(ukernel, arch_flags, element_tile, init_params); +#include "f32-raddextexp/f32-raddextexp.h" +#undef XNN_UKERNEL_WITH_PARAMS diff --git a/test/f32-raddextexp.yaml b/test/f32-raddextexp.yaml deleted file mode 100644 index ac18103efef7..000000000000 --- a/test/f32-raddextexp.yaml +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2019 Google LLC -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree.
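The RAddExtExpMicrokernelTester above checks the kernel's two-element result against a double-precision reference: a raddextexp microkernel returns the sum of exponentials in an extended-exponent form, so inputs in the 90..100 range (where expf overflows) stay representable. A minimal sketch of how a caller would reconstruct the sum from that pair, mirroring the ASSERT_NEAR check in the tester (buffer and variable names here are illustrative only):

    // sum[0] carries the value part, sum[1] the base-2 exponent part.
    float sum[2];
    raddextexp(elements * sizeof(float), x, sum);
    // The mathematical result is sum[0] * 2^sum[1]; combine in double precision,
    // since the whole point of the extended format is that the product may not
    // fit in a finite float.
    const double total = std::exp2(double(sum[1])) * double(sum[0]);

The XNN_TEST_RADDEXTEXP_* macros are then expanded once per kernel listed in src/f32-raddextexp/f32-raddextexp.h through the XNN_UKERNEL_WITH_PARAMS include above, which is what lets this one file replace both the generated per-kernel TEST bodies and the deleted f32-raddextexp.yaml specification.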
- -# x86 AVX -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u64 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc2 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u64_acc4 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u72 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u72_acc3 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u80 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc2 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u80_acc5 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u96 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc2 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc3 -- name: xnn_f32_raddextexp_ukernel__avx2_p5_u96_acc6 -# x86 AVX512 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc2 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u128_acc4 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc2 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u160_acc5 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc2 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc3 -- name: xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u192_acc6 diff --git a/test/fully-connected-nc.cc b/test/fully-connected-nc.cc index 2a89078fa518..7682ebb755fa 100644 --- a/test/fully-connected-nc.cc +++ b/test/fully-connected-nc.cc @@ -2258,3 +2258,182 @@ TEST(FULLY_CONNECTED_NC_QP8_F32_QB4W, small_batch_with_output_stride) { .iterations(3) .TestQP8F32QB4W(); } + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch_with_qmin) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .qmin(128) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch_with_qmax) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .qmax(128) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch_with_input_stride) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(22) + .input_stride(28) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch_with_output_stride) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .output_stride(29) + .iterations(3) + .TestQP8F32QC8W(); +} + +// TODO(b/355416339): Re-enable once we can handle strides again +TEST(DISABLED_FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch_transpose_weights) { + FullyConnectedOperatorTester() + .transpose_weights(true) + .batch_size(1) + .input_channels(22) + .output_channels(20) // legacy requires even number + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, unit_batch_without_bias) { + FullyConnectedOperatorTester() + .has_bias(false) + .batch_size(1) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + 
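Each of these FULLY_CONNECTED_NC_QP8_F32_QC8W cases drives the same operator-API sequence through the TestQP8F32QC8W helper added to the tester further down. A condensed sketch of that sequence, with error handling and the weights-cache plumbing elided and with placeholder buffer names (input_qp8 must already be packed with the xnn_x8_packq_f32qp8 packing routine, as the helper does):

    ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
    xnn_operator_t fc = nullptr;
    // Create the operator from per-channel int8 weights, float scales, and float bias.
    xnn_create_fully_connected_nc_qp8_f32_qc8w(
        input_channels, output_channels, input_stride, output_stride,
        kernel_scale, kernel, bias, output_min, output_max,
        /*flags=*/0, nullptr, /*weights_cache=*/nullptr, &fc);
    // Shape it for the batch, bind the packed qp8 input and the f32 output, then run.
    xnn_reshape_fully_connected_nc_qp8_f32_qc8w(fc, batch_size, /*threadpool=*/nullptr);
    xnn_setup_fully_connected_nc_qp8_f32_qc8w(fc, input_qp8, output);
    xnn_run_operator(fc, /*threadpool=*/nullptr);
    xnn_delete_operator(fc);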
+TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch_with_qmin) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .qmin(128) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch_with_qmax) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .qmax(128) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch_with_input_stride) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(22) + .input_stride(29) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch_with_output_stride) { + FullyConnectedOperatorTester() + .batch_size(12) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .output_stride(29) + .iterations(3) + .TestQP8F32QC8W(); +} + +// TODO(b/355416339): Re-enable once we can handle strides again +TEST(DISABLED_FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch_transpose_weights) { + FullyConnectedOperatorTester() + .transpose_weights(true) + .batch_size(12) + .input_channels(22) + .output_channels(20) // legacy doesn't support odd nc + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, small_batch_without_bias) { + FullyConnectedOperatorTester() + .has_bias(false) + .batch_size(12) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .iterations(3) + .TestQP8F32QC8W(); +} + +TEST(FULLY_CONNECTED_NC_QP8_F32_QC8W, weights_cache_unit_batch) { + FullyConnectedOperatorTester() + .batch_size(1) + .input_channels(22) + .output_channels(19) + .kernel_zero_point(8) + .use_weights_cache(true) + .iterations(3) + .TestQP8F32QC8W(); +} + +// TODO(b/355416339): Re-enable once we can handle strides again +TEST(DISABLED_FULLY_CONNECTED_NC_QP8_F32_QC8W, weights_cache_unit_batch_transpose_weights) { + FullyConnectedOperatorTester() + .transpose_weights(true) + .batch_size(1) + .input_channels(22) + .output_channels(20) // legacy doesn't support odd nc + .kernel_zero_point(8) + .use_weights_cache(true) + .iterations(3) + .TestQP8F32QC8W(); +} + diff --git a/test/fully-connected-operator-tester.h b/test/fully-connected-operator-tester.h index 5ed95e7a0464..fa7e46707981 100644 --- a/test/fully-connected-operator-tester.h +++ b/test/fully-connected-operator-tester.h @@ -628,7 +628,7 @@ class FullyConnectedOperatorTester { xnnpack::Buffer quantization_params(batch_size() + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer kernel_scale(output_channels()); - for (size_t iteration = 0; iteration < iterations(); iteration++) { + { // for (size_t iteration = 0; iteration < iterations(); iteration++) { std::generate(input.begin(), input.end(), [&]() { return w8dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); @@ -698,7 +698,6 @@ class FullyConnectedOperatorTester { value = std::max(std::min(value, output_max), output_min); } - // Create, setup, run, and destroy Fully Connected operator. 
ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t fully_connected_op = nullptr; @@ -1212,6 +1211,213 @@ class FullyConnectedOperatorTester { } } + void TestQP8F32QC8W() const { + // Get the parameters of this GEMM, skip if not available. + const struct xnn_gemm_config* gemm_config = + xnn_init_qp8_f32_qc8w_gemm_config(); + if (gemm_config == nullptr) { + GTEST_SKIP(); + } + + // Note that the microkernel will force `mr` to 1 if `mc` is 1, so we have + // to anticipate that when packing the left-hand operand. + const uint32_t mr_packed = batch_size() > 1 ? gemm_config->mr_packed : 1; + const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr; + const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr; + + ASSERT_EQ(weights_type(), WeightsType::Default); + + xnnpack::ReplicableRandomDevice rng; + std::uniform_real_distribution f32dist(-1.f, 1.f); + std::uniform_real_distribution f32idist(0.5f, 2.0f); + std::uniform_int_distribution w8dist( + std::numeric_limits::min(), std::numeric_limits::max()); + + const size_t k = input_channels(); + + xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(float) + + (batch_size() - 1) * input_stride() + + input_channels()); + const size_t kernel_stride = transpose_weights() ? output_channels() : k; + xnnpack::Buffer kernel(k * output_channels()); + xnnpack::Buffer bias(output_channels()); + xnnpack::Buffer output((batch_size() - 1) * output_stride() + + output_channels()); + xnnpack::Buffer output_ref(batch_size() * output_channels()); + xnnpack::Buffer kernel_scale(output_channels()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); + std::generate(kernel.begin(), kernel.end(), + [&]() { return w8dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return f32idist(rng); }); + std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); + + // Quantize the left-hand operand. + const size_t input_packed_size = + xnn_x8_packq_f32qp8_packed_size(batch_size(), k, mr_packed, kr, sr); + xnnpack::Buffer input_qp8(input_packed_size); + xnn_x8_packq_f32qp8_ukernel__scalar_u1(batch_size(), k, mr_packed, kr, sr, + /*m_idx_start=*/0, input.data(), + /*lhs_stride=*/k * sizeof(float), + input_qp8.data()); + + // Compute reference results, without renormalization. 
+ std::fill(output_ref.begin(), output_ref.end(), 0); + if (transpose_weights()) { + for (size_t mi = 0; mi < batch_size(); mi++) { + for (size_t ni = 0; ni < output_channels(); ni++) { + for (size_t ki = 0; ki < input_channels(); ki++) { + const size_t kernel_index = ki * kernel_stride + ni; + int8_t kernel_value = kernel[kernel_index]; + output_ref[mi * output_channels() + ni] += + xnn_x8_packq_f32qp8_get_dequantized(mi, ki, input_qp8.data(), + k, mr_packed, kr, sr) * + static_cast(static_cast(kernel_value)); + } + output_ref[mi * output_channels() + ni] *= kernel_scale[ni]; + if (has_bias()) { + output_ref[mi * output_channels() + ni] += bias[ni]; + } + } + } + } else { + for (size_t mi = 0; mi < batch_size(); mi++) { + for (size_t ni = 0; ni < output_channels(); ni++) { + for (size_t ki = 0; ki < input_channels(); ki++) { + const size_t kernel_index = ni * kernel_stride + ki; + int8_t kernel_value = kernel[kernel_index]; + output_ref[mi * output_channels() + ni] += + xnn_x8_packq_f32qp8_get_dequantized(mi, ki, input_qp8.data(), + k, mr_packed, kr, sr) * + static_cast(static_cast(kernel_value)); + } + output_ref[mi * output_channels() + ni] *= kernel_scale[ni]; + if (has_bias()) { + output_ref[mi * output_channels() + ni] += bias[ni]; + } + } + } + } + + // Compute clamping parameters. + const float accumulated_max = + *std::max_element(output_ref.cbegin(), output_ref.cend()); + const float accumulated_min = + *std::min_element(output_ref.cbegin(), output_ref.cend()); + + const float output_min = + qmin() == 0 + ? -std::numeric_limits::infinity() + : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * + static_cast(qmin()); + const float output_max = + qmax() == 255 + ? std::numeric_limits::infinity() + : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * + static_cast(255 - qmax()); + + // Clamp reference results. + for (float& value : output_ref) { + value = std::max(std::min(value, output_max), output_min); + } + + // Create, setup, run, and destroy Fully Connected operator. + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + xnn_operator_t fully_connected_op = nullptr; + + struct xnn_internal_weights_cache* internal_weights_cache = nullptr; + std::unique_ptr + auto_weights_cache(nullptr, xnn_delete_weights_cache); + if (use_weights_cache()) { + xnn_weights_cache_t weights_cache = nullptr; + xnn_create_weights_cache(&weights_cache); + auto_weights_cache.reset(weights_cache); + if (weights_cache) { + internal_weights_cache = + (struct xnn_internal_weights_cache*)weights_cache->context; + } + } + + const xnn_status status = xnn_create_fully_connected_nc_qp8_f32_qc8w( + input_channels(), output_channels(), input_stride(), output_stride(), + kernel_scale.data(), kernel.data(), + has_bias() ? bias.data() : nullptr, output_min, output_max, + transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, nullptr, + auto_weights_cache.get(), &fully_connected_op); + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, fully_connected_op); + if (use_weights_cache()) { + ASSERT_EQ(xnn_status_success, + xnn_finalize_weights_cache( + auto_weights_cache.get(), + xnn_weights_cache_finalization_kind_soft)); + } + + // Smart pointer to automatically delete fully_connected_op. 
+ std::unique_ptr + auto_fully_connected_op(fully_connected_op, xnn_delete_operator); + + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qc8w( + fully_connected_op, batch_size(), + /*threadpool=*/nullptr)); + + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qp8_f32_qc8w( + fully_connected_op, input_qp8.data(), output.data())); + + ASSERT_EQ(xnn_status_success, + xnn_run_operator(fully_connected_op, /*threadpool=*/nullptr)); + + // Verify results. + VerifyF32(output, output_ref, output_max, output_min); + + if (use_weights_cache()) { + // Create another operator with the same weights cache. + xnn_operator_t fully_connected_op2 = nullptr; + size_t old_weights_cache_size = + internal_weights_cache->cache.weights.size; + + ASSERT_EQ( + xnn_status_success, + xnn_create_fully_connected_nc_qp8_f32_qc8w( + input_channels(), output_channels(), input_stride(), + output_stride(), kernel_scale.data(), kernel.data(), + has_bias() ? bias.data() : nullptr, output_min, output_max, + transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0, nullptr, + auto_weights_cache.get(), &fully_connected_op2)); + ASSERT_NE(nullptr, fully_connected_op2); + + // Smart pointer to automatically delete fully_connected_op. + std::unique_ptr + auto_fully_connected_op(fully_connected_op2, xnn_delete_operator); + + ASSERT_EQ(xnn_status_success, + xnn_reshape_fully_connected_nc_qp8_f32_qc8w( + fully_connected_op2, batch_size(), + /*threadpool=*/nullptr)); + + xnnpack::Buffer output2(output.size()); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qp8_f32_qc8w( + fully_connected_op2, input_qp8.data(), output2.data())); + + ASSERT_EQ(xnn_status_success, xnn_run_operator(fully_connected_op2, + /*threadpool=*/nullptr)); + + VerifyWeightsCache(*internal_weights_cache, old_weights_cache_size); + + VerifyF32(output, output_ref, output_max, output_min); + } + } + } + void TestQP8F32QB4W() const { // Get the parameters of this GEMM, skip if not available. 
const struct xnn_gemm_config* gemm_config = @@ -1614,6 +1820,7 @@ class FullyConnectedOperatorTester { std::uniform_real_distribution f32idist(0.5f, 2.0f); std::uniform_int_distribution w8dist(std::numeric_limits::min(), std::numeric_limits::max()); + // Need to adjust input and quantization_parmams xnnpack::Buffer input(XNN_EXTRA_BYTES / sizeof(int8_t) + (batch_size() - 1) * input_stride() + input_channels()); xnnpack::Buffer kernel(output_channels() * input_channels()); diff --git a/test/fully-connected.cc b/test/fully-connected.cc index 435701e3e97a..55de4b9e1990 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -19,31 +19,26 @@ #include #include #include "xnnpack.h" +#include "xnnpack/buffer.h" #include "xnnpack/common.h" -#include "xnnpack/config.h" -#include "xnnpack/internal.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" #include "xnnpack/operator.h" -#include "xnnpack/packq.h" #include "xnnpack/requantization.h" #include "xnnpack/subgraph.h" -#include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "runtime-flags.h" using testing::ElementsAreArray; struct FullyConnectedTestParam { - FullyConnectedTestParam(bool use_bias_, int8_t kernel_zero_point_) - : use_bias(use_bias_), kernel_zero_point(kernel_zero_point_) {} explicit FullyConnectedTestParam(bool use_bias_) : use_bias(use_bias_) {} bool use_bias; - int8_t kernel_zero_point; }; template + int WeightsPerElement = 1> class FullyConnectedTestBase : public ::testing::TestWithParam { protected: @@ -51,12 +46,13 @@ class FullyConnectedTestBase f32dist = std::uniform_real_distribution(0.1f, 1.0f); scale_dist = std::uniform_real_distribution(1.0f, 5.0f); i32dist = std::uniform_int_distribution(-10000, 10000); - auto shape_dist = std::uniform_int_distribution(2, XNN_MAX_TENSOR_DIMS); + auto shape_dist = std::uniform_int_distribution(2, 3); dim_dist = std::uniform_int_distribution(5, 15); - i8dist = - std::uniform_int_distribution(std::numeric_limits::min(), std::numeric_limits::max()); - w8dist = - std::uniform_int_distribution(-std::numeric_limits::max(), std::numeric_limits::max()); + i8dist = std::uniform_int_distribution( + std::numeric_limits::min(), std::numeric_limits::max()); + w8dist = std::uniform_int_distribution( + -std::numeric_limits::max(), + std::numeric_limits::max()); output_min = -std::numeric_limits::infinity(); output_max = std::numeric_limits::infinity(); @@ -66,11 +62,8 @@ class FullyConnectedTestBase assert(input_dims.size() >= 2); output_channels = dim_dist(rng); input_channels = input_dims.back(); - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. 
- if (even_channels) { - input_channels = round_up_po2(input_channels, 2); - input_dims.back() = input_channels; - } + input_channels = round_up_po2(input_channels, WeightsPerElement); + input_dims.back() = input_channels; kernel_dims = {output_channels, input_channels}; kernel_dims_tranposed = {input_channels, output_channels}; bias_dims = {output_channels}; @@ -79,9 +72,12 @@ class FullyConnectedTestBase batch_size = NumElements(input_dims) / input_channels; - input = xnnpack::Buffer(XNN_EXTRA_BYTES / sizeof(InputType) + NumElements(input_dims)); - kernel = xnnpack::Buffer(input_channels * output_channels); - kernel_fp16 = xnnpack::Buffer(input_channels * output_channels); + input = xnnpack::Buffer(XNN_EXTRA_BYTES / sizeof(InputType) + + NumElements(input_dims)); + kernel = xnnpack::Buffer(input_channels * output_channels / + WeightsPerElement); + kernel_fp16 = + xnnpack::Buffer(input_channels * output_channels); bias = xnnpack::Buffer(output_channels); bias_fp16 = xnnpack::Buffer(output_channels); operator_output = xnnpack::Buffer(NumElements(output_dims)); @@ -89,16 +85,15 @@ class FullyConnectedTestBase accumulators = xnnpack::Buffer(batch_size * output_channels); } - std::vector RandomShape(size_t num_dims) - { + std::vector RandomShape(size_t num_dims) { std::vector dims(num_dims); std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); }); return dims; } - size_t NumElements(std::vector& dims) - { - return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies()); + size_t NumElements(std::vector& dims) { + return std::accumulate(dims.begin(), dims.end(), size_t(1), + std::multiplies()); } xnnpack::ReplicableRandomDevice rng; @@ -133,15 +128,34 @@ class FullyConnectedTestBase xnnpack::Buffer accumulators; }; -class FullyConnectedTestQP8F32QC4W - : public FullyConnectedTestBase {}; +template +class QuantizedFullyConnectedTestBase + : public FullyConnectedTestBase { + protected: + void initialize_accumulators_from_bias() { + for (size_t i = 0; i < this->batch_size; i++) { + for (size_t oc = 0; oc < this->output_channels; oc++) { + this->accumulators[i * this->output_channels + oc] = this->bias[oc]; + } + } + } +}; + +class FullyConnectedTestF32QC4W + : public FullyConnectedTestBase {}; -TEST_P(FullyConnectedTestQP8F32QC4W, define) { - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); +class FullyConnectedTestF32QC8W + : public FullyConnectedTestBase {}; - if (xnn_init_qp8_f32_qc4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } +using FullyConnectedTestQC8 = QuantizedFullyConnectedTestBase; +using FullyConnectedTestQS8 = QuantizedFullyConnectedTestBase; +using FullyConnectedTestQU8 = QuantizedFullyConnectedTestBase; +using FullyConnectedTestF16 = FullyConnectedTestBase; +using FullyConnectedTestF32 = FullyConnectedTestBase; +using DynamicFullyConnectedTestF32 = FullyConnectedTestBase; + +TEST_F(FullyConnectedTestQC8, define) { + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); @@ -150,42 +164,33 @@ TEST_P(FullyConnectedTestQP8F32QC4W, define) { uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qpint8, input_dims.size(), - /*num_nonbatch_dims=*/1, input_dims.data(), + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, 0, 1.0f, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, 
/*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - // Adjust number of kernel elements for QC4W. input_channels should be padded - // to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); - - const uint8_t kernel_zero_point = GetParam().kernel_zero_point; - xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), - [&]() { return scale_dist(rng); }); + xnnpack::Buffer scale(output_channels, 1.0f); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, - kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, kernel_dims.data(), kernel.data(), + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, scale.data(), kernel_dims.size(), + 0, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), - bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint32, scale.data(), bias_dims.size(), + 0, bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), - output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, 0, 1.0f, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( @@ -207,879 +212,165 @@ TEST_P(FullyConnectedTestQP8F32QC4W, define) { ASSERT_EQ(node->flags, 0); } -TEST_P(FullyConnectedTestQP8F32QC4W, matches_qd8_f32_qc4w) { +TEST_F(FullyConnectedTestQS8, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - if (xnn_init_qp8_f32_qc4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, - /*flags=*/0, &subgraph)); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); std::unique_ptr auto_subgraph( subgraph, xnn_delete_subgraph); - xnnpack::Buffer convert_input(batch_size * input_channels + - XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_qp8_data( - xnn_x8_packq_f32qp8_gemm_packed_size(batch_size, input_channels) + - XNN_EXTRA_BYTES); - xnnpack::Buffer operator_qd8_data(batch_size * input_channels + - XNN_EXTRA_BYTES); - xnnpack::Buffer qp8_operator_output(batch_size * output_channels); - xnnpack::Buffer qd8_operator_output(batch_size * output_channels); - - // Adjust number of kernel elements for QC4W. input_channels should be padded - // to byte boundary, hence even. 
- const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); - - xnnpack::Buffer kernel_scale(output_channels); - xnnpack::Buffer quantization_params( - batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); - std::generate(kernel_scale.begin(), kernel_scale.end(), - [&]() { return scale_dist(rng); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), - [&]() { return f32dist(rng); }); - - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); - const uint8_t kernel_zero_point = GetParam().kernel_zero_point; - - // Call operator API for `qp8`. - xnn_operator_t qp8_convert_op = nullptr; - xnn_operator_t qp8_fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qp8( - /*flags=*/0, &qp8_convert_op); - std::unique_ptr - auto_qp8_convert_op(qp8_convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qp8_convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8( - qp8_convert_op, batch_size, input_channels, - input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qp8(qp8_convert_op, convert_input.data(), - operator_qp8_data.data())); + uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_run_operator(qp8_convert_op, /*threadpool=*/nullptr)); - - status = xnn_create_fully_connected_nc_qp8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(), - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &qp8_fc_op); - std::unique_ptr auto_qp8_fc_op( - qp8_fc_op, xnn_delete_operator); - - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, 0, 1.0f, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qp8_fc_op); - ASSERT_EQ(xnn_status_success, - xnn_reshape_fully_connected_nc_qp8_f32_qc4w( - qp8_fc_op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qp8_f32_qc4w( - qp8_fc_op, operator_qp8_data.data(), - qp8_operator_output.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(qp8_fc_op, /*threadpool=*/nullptr)); - - // Call operator API for `qd8`. 
- xnn_operator_t qd8_convert_op = nullptr; - xnn_operator_t qd8_fc_op = nullptr; - status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &qd8_convert_op); - std::unique_ptr - auto_qd8_convert_op(qd8_convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qd8_convert_op); - ASSERT_EQ(xnn_status_success, - xnn_reshape_convert_nc_f32_qd8( - qd8_convert_op, batch_size, input_channels, input_channels, - input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qd8(qd8_convert_op, convert_input.data(), - operator_qd8_data.data(), - quantization_params.data())); + uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_run_operator(qd8_convert_op, /*threadpool=*/nullptr)); - - status = xnn_create_fully_connected_nc_qd8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(), - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &qd8_fc_op); - std::unique_ptr auto_qd8_fc_op( - qd8_fc_op, xnn_delete_operator); - - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, 0, 1.0f, kernel_dims.size(), + kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qd8_fc_op); - ASSERT_EQ(xnn_status_success, - xnn_reshape_fully_connected_nc_qd8_f32_qc4w( - qd8_fc_op, batch_size, /*threadpool=*/nullptr)); + uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qc4w( - qd8_fc_op, operator_qd8_data.data(), qd8_operator_output.data(), - quantization_params.data())); + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint32, 0, 1.0f, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + + uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_run_operator(qd8_fc_op, /*threadpool=*/nullptr)); + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, 0, 1.0f, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); + ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - // Compare the outputs. Note that the values will not be exactly the same - // since the `qd8` quantization rounds to zero, whereas the `qp8` quantization - // does not. 
- float max_abs_val = 0.0f; - for (size_t i = 0; i < batch_size * output_channels; i++) { - max_abs_val = std::max(max_abs_val, std::abs(qd8_operator_output[i])); - } - for (size_t i = 0; i < batch_size * output_channels; i++) { - ASSERT_NEAR(qp8_operator_output[i], qd8_operator_output[i], - max_abs_val * 1e-2); - } + ASSERT_EQ( + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); + + ASSERT_EQ(subgraph->num_nodes, 1); + const struct xnn_node* node = &subgraph->nodes[0]; + ASSERT_EQ(node->type, xnn_node_type_fully_connected); + ASSERT_EQ(node->activation.output_min, output_min); + ASSERT_EQ(node->activation.output_max, output_max); + ASSERT_EQ(node->num_inputs, 3); + ASSERT_EQ(node->inputs[0], input_id); + ASSERT_EQ(node->inputs[1], kernel_id); + ASSERT_EQ(node->inputs[2], bias_id); + ASSERT_EQ(node->num_outputs, 1); + ASSERT_EQ(node->outputs[0], output_id); + ASSERT_EQ(node->flags, 0); } -TEST_P(FullyConnectedTestQP8F32QC4W, matches_operator_api) { +TEST_F(FullyConnectedTestQU8, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - if (xnn_init_qp8_f32_qc4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, - /*flags=*/0, &subgraph)); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); std::unique_ptr auto_subgraph( subgraph, xnn_delete_subgraph); - uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + - XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data( - xnn_x8_packq_f32qp8_gemm_packed_size(batch_size, input_channels)); - xnnpack::Buffer subgraph_output(batch_size * output_channels); - xnnpack::Buffer operator_output(batch_size * output_channels); - - // Adjust number of kernel elements for QC4W. input_channels should be padded - // to byte boundary, hence even. 
- const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); - xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), - [&]() { return scale_dist(rng); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), - [&]() { return f32dist(rng); }); + uint32_t input_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_quint8, 0, 1.0f, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); + uint32_t kernel_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_quint8, 0, 1.0f, kernel_dims.size(), + kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); - const uint8_t kernel_zero_point = GetParam().kernel_zero_point; + uint32_t bias_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint32, 0, 1.0f, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); - // Call operator API. - xnn_operator_t convert_op = nullptr; - xnn_operator_t fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qp8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op( - convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8( - convert_op, batch_size, input_channels, - input_channels, /*threadpool=*/nullptr)); + uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qp8(convert_op, convert_input.data(), - operator_dq_data.data())); + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_quint8, 0, 1.0f, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); + ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); + ASSERT_EQ(xnn_status_success, - xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + xnn_define_fully_connected(subgraph, output_min, output_max, + input_id, kernel_id, bias_id, output_id, + /*flags=*/0)); - status = xnn_create_fully_connected_nc_qp8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(), - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op( - fc_op, xnn_delete_operator); + ASSERT_EQ(subgraph->num_nodes, 1); + const struct xnn_node* node = &subgraph->nodes[0]; + ASSERT_EQ(node->type, xnn_node_type_fully_connected); + ASSERT_EQ(node->activation.output_min, output_min); + ASSERT_EQ(node->activation.output_max, output_max); + ASSERT_EQ(node->num_inputs, 3); + ASSERT_EQ(node->inputs[0], input_id); + ASSERT_EQ(node->inputs[1], kernel_id); + ASSERT_EQ(node->inputs[2], bias_id); + ASSERT_EQ(node->num_outputs, 1); + ASSERT_EQ(node->outputs[0], output_id); + ASSERT_EQ(node->flags, 0); +} - if (status == 
xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } +TEST_P(FullyConnectedTestF16, define) { + bool use_bias = GetParam().use_bias; + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qc4w( - fc_op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qp8_f32_qc4w( - fc_op, operator_dq_data.data(), operator_output.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_subgraph_t subgraph = nullptr; + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); - // Call subgraph API. - ASSERT_EQ(xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), - input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_NODE_ID); + uint32_t input_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), - /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); - ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ(xnn_status_success, - xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, - kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + kernel_dims.size(), kernel_dims.data(), + kernel.data(), /*external_id=*/1, + /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; + if (use_bias) { + ASSERT_EQ( + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + } + + uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( xnn_status_success, - xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), - bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), - output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, - &output_id)); - ASSERT_NE(output_id, XNN_INVALID_NODE_ID); - - ASSERT_EQ(xnn_status_success, - xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, - /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_GEMM)); - ASSERT_EQ(xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, - dq_quantized_id, kernel_id, bias_id, - output_id, /*flags=*/0)); - - // Make sure the quantized inputs were coerced to `qpint8`. 
- ASSERT_EQ(subgraph->num_nodes, 2); - const struct xnn_node* fc_node = &subgraph->nodes[1]; - ASSERT_EQ(fc_node->type, xnn_node_type_fully_connected); - - xnn_runtime_t runtime = nullptr; - ASSERT_EQ( - xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); - ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime( - runtime, xnn_delete_runtime); - std::array external = { - xnn_external_value{input_id, convert_input.data()}, - xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, - xnn_setup_runtime(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - - EXPECT_EQ(subgraph_output, operator_output); -} - -TEST_P(FullyConnectedTestQP8F32QC4W, matches_operator_api_with_reshape) { - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - if (xnn_init_qp8_f32_qc4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, - /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph( - subgraph, xnn_delete_subgraph); - uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(5 * batch_size * input_channels + - XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer subgraph_output(5 * batch_size * output_channels); - xnnpack::Buffer operator_dq_data( - xnn_x8_packq_f32qp8_gemm_packed_size(batch_size, input_channels) + - XNN_EXTRA_BYTES); - xnnpack::Buffer operator_output(5 * batch_size * output_channels); - - // These must be initialized due to the design of the test, which assumes - // unwritten portions of these buffers are matching. - std::fill(convert_input.begin(), convert_input.end(), 0.0f); - std::fill(subgraph_output.begin(), subgraph_output.end(), 0.0f); - std::fill(operator_output.begin(), operator_output.end(), 0.0f); - - // Adjust number of kernel elements for QC4W. input_channels should be padded - // to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); - - xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), - [&]() { return scale_dist(rng); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), - [&]() { return f32dist(rng); }); - - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); - - const uint8_t kernel_zero_point = GetParam().kernel_zero_point; - - // Call operator API. 
- xnn_operator_t convert_op = nullptr; - xnn_operator_t fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qp8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op( - convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8( - convert_op, batch_size, input_channels, - input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qp8(convert_op, convert_input.data(), - operator_dq_data.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(convert_op, /*threadpool=*/nullptr)); - - status = xnn_create_fully_connected_nc_qp8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, kernel_scale.data(), kernel.data(), nullptr, - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op( - fc_op, xnn_delete_operator); - - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qc4w( - fc_op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qp8_f32_qc4w( - fc_op, operator_dq_data.data(), operator_output.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(fc_op, /*threadpool=*/nullptr)); - - // Call subgraph API. - // - // dim[0] size increments: - // - // 0..........2......3.......4 - // ^ ^ ^ ^ - // `..Input1 | | | - // Input2..' | | - // Subgraph.......' | - // Input3.................' - - std::vector subgraph_input_dims(input_dims); - std::vector subgraph_output_dims(output_dims); - subgraph_input_dims[0] += 3; - subgraph_output_dims[0] += 3; - - ASSERT_EQ(xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, subgraph_input_dims.size(), - subgraph_input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_NODE_ID); - - uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, subgraph_input_dims.size(), - /*num_nonbatch_dims=*/1, subgraph_input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); - ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, - kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, subgraph_output_dims.size(), - subgraph_output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, - &output_id)); - ASSERT_NE(output_id, XNN_INVALID_NODE_ID); - - xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, - xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, - /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_GEMM)); - ASSERT_EQ(xnn_status_success, - xnn_define_fully_connected( - subgraph, output_min, output_max, dq_quantized_id, kernel_id, - 
XNN_INVALID_NODE_ID, output_id, /*flags=*/0)); - - // Make sure the quantized inputs were coerced to `qpint8`. - ASSERT_EQ(subgraph->num_nodes, 2); - const struct xnn_node* fc_node = &subgraph->nodes[1]; - ASSERT_EQ(fc_node->type, xnn_node_type_fully_connected); - - ASSERT_EQ( - xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); - ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime( - runtime, xnn_delete_runtime); - - struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_convert); - - // 1st inference: lets start smaller than we planned memory for - std::array external = { - xnn_external_value{input_id, convert_input.data()}, - xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, - xnn_reshape_external_value(runtime, input_id, input_dims.size(), - input_dims.data())); - ASSERT_EQ(xnn_status_success, xnn_reshape_runtime(runtime)); - ASSERT_EQ(xnn_status_success, - xnn_setup_runtime_v2(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - EXPECT_EQ(subgraph_output, operator_output); - - // 2nd inference: The dq-params should be properly allocated to handle a - // resize without memory retrigger - input_dims[0] += 2; - ASSERT_EQ(xnn_status_success, - xnn_reshape_external_value(runtime, input_id, input_dims.size(), - input_dims.data())); - ASSERT_EQ(xnn_status_success, xnn_reshape_runtime(runtime)); - ASSERT_EQ(xnn_status_success, - xnn_setup_runtime_v2(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - - // 3rd inference: The dq-params should be properly allocated even with memory - // retrigger - input_dims[0] += 2; // +4 total - ASSERT_EQ(xnn_status_success, - xnn_reshape_external_value(runtime, input_id, input_dims.size(), - input_dims.data())); - ASSERT_EQ(xnn_status_success, xnn_reshape_runtime(runtime)); - ASSERT_EQ(xnn_status_success, - xnn_setup_runtime_v2(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); -} - -TEST_P(FullyConnectedTestQP8F32QC4W, matches_operator_api_transposed_weights) { - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - if (xnn_init_qp8_f32_qc4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, - /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph( - subgraph, xnn_delete_subgraph); - uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + - XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data( - xnn_x8_packq_f32qp8_gemm_packed_size(batch_size, input_channels) + - XNN_EXTRA_BYTES); - xnnpack::Buffer subgraph_output(batch_size * output_channels); - xnnpack::Buffer operator_output(batch_size * output_channels); - - // Adjust number of kernel elements for QC4W. input_channels should be padded - // to byte boundary, hence even. 
- const size_t rounded_output_channels = round_up_po2(output_channels, 2); - kernel = xnnpack::Buffer(input_channels * rounded_output_channels); - - xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), - [&]() { return scale_dist(rng); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), - [&]() { return f32dist(rng); }); - - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); - - const uint8_t kernel_zero_point = GetParam().kernel_zero_point; - - // Call operator API. - xnn_operator_t convert_op = nullptr; - xnn_operator_t fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qp8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op( - convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8( - convert_op, batch_size, input_channels, - input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qp8(convert_op, convert_input.data(), - operator_dq_data.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(convert_op, /*threadpool=*/nullptr)); - - status = xnn_create_fully_connected_nc_qp8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(), - output_min, output_max, XNN_FLAG_TRANSPOSE_WEIGHTS, nullptr, nullptr, - &fc_op); - std::unique_ptr auto_fc_op( - fc_op, xnn_delete_operator); - - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qc4w( - fc_op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qp8_f32_qc4w( - fc_op, operator_dq_data.data(), operator_output.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(fc_op, /*threadpool=*/nullptr)); - - // Call subgraph API. 
- ASSERT_EQ(xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), - input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_NODE_ID); - - uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), - /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); - ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, - kernel_scale.data(), kernel_dims_tranposed.size(), - /*channel_dim=*/1, kernel_dims_tranposed.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), - bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ(xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), - output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, - &output_id)); - ASSERT_NE(output_id, XNN_INVALID_NODE_ID); - - xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, - xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, - /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_GEMM)); - ASSERT_EQ(xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, - dq_quantized_id, kernel_id, bias_id, - output_id, XNN_FLAG_TRANSPOSE_WEIGHTS)); - - // Make sure the quantized inputs were coerced to `qpint8`. 
- ASSERT_EQ(subgraph->num_nodes, 2); - const struct xnn_node* fc_node = &subgraph->nodes[1]; - ASSERT_EQ(fc_node->type, xnn_node_type_fully_connected); - - ASSERT_EQ( - xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); - ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime( - runtime, xnn_delete_runtime); - std::array external = { - xnn_external_value{input_id, convert_input.data()}, - xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, - xnn_setup_runtime(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - - EXPECT_EQ(subgraph_output, operator_output); -} - -INSTANTIATE_TEST_SUITE_P(FullyConnectedTestQP8F32QC4W, - FullyConnectedTestQP8F32QC4W, - testing::ValuesIn( - {FullyConnectedTestParam(true, 0), - FullyConnectedTestParam(true, 8)})); - -template class QuantizedFullyConnectedTestBase : public FullyConnectedTestBase { -protected: - void initialize_accumulators_from_bias() - { - for (size_t i = 0; i < this->batch_size; i++) { - for (size_t oc = 0; oc < this->output_channels; oc++) { - this->accumulators[i * this->output_channels + oc] = this->bias[oc]; - } - } - } -}; - -class FullyConnectedTestF32QC4W : public FullyConnectedTestBase { -}; - -class FullyConnectedTestF32QC8W : public FullyConnectedTestBase { -}; -class FullyConnectedTestQP8F32QB4W - : public FullyConnectedTestBase {}; - - -using FullyConnectedTestQC8 = QuantizedFullyConnectedTestBase; -using FullyConnectedTestQS8 = QuantizedFullyConnectedTestBase; -using FullyConnectedTestQU8 = QuantizedFullyConnectedTestBase; -using FullyConnectedTestF16 = FullyConnectedTestBase; -using FullyConnectedTestF32 = FullyConnectedTestBase; -using DynamicFullyConnectedTestF32 = FullyConnectedTestBase; - -TEST_F(FullyConnectedTestQC8, define) -{ - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - - uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, 0, 1.0f, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - - xnnpack::Buffer scale(output_channels, 1.0f); - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, scale.data(), kernel_dims.size(), 0, kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint32, scale.data(), bias_dims.size(), 0, bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - - uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, 0, 1.0f, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); - - ASSERT_EQ(subgraph->num_nodes, 1); - const struct xnn_node* node = 
&subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_fully_connected); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); - ASSERT_EQ(node->num_inputs, 3); - ASSERT_EQ(node->inputs[0], input_id); - ASSERT_EQ(node->inputs[1], kernel_id); - ASSERT_EQ(node->inputs[2], bias_id); - ASSERT_EQ(node->num_outputs, 1); - ASSERT_EQ(node->outputs[0], output_id); - ASSERT_EQ(node->flags, 0); -} - -TEST_F(FullyConnectedTestQS8, define) -{ - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - - uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, 0, 1.0f, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, 0, 1.0f, kernel_dims.size(), kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint32, 0, 1.0f, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - - uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, 0, 1.0f, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); - - ASSERT_EQ(subgraph->num_nodes, 1); - const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_fully_connected); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); - ASSERT_EQ(node->num_inputs, 3); - ASSERT_EQ(node->inputs[0], input_id); - ASSERT_EQ(node->inputs[1], kernel_id); - ASSERT_EQ(node->inputs[2], bias_id); - ASSERT_EQ(node->num_outputs, 1); - ASSERT_EQ(node->outputs[0], output_id); - ASSERT_EQ(node->flags, 0); -} - -TEST_F(FullyConnectedTestQU8, define) -{ - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - - uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, 0, 1.0f, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, 0, 1.0f, kernel_dims.size(), kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint32, 0, 1.0f, 
bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - - uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, 0, 1.0f, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - - ASSERT_EQ( - xnn_status_success, xnn_define_fully_connected( - subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, - /*flags=*/0)); - - ASSERT_EQ(subgraph->num_nodes, 1); - const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_fully_connected); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); - ASSERT_EQ(node->num_inputs, 3); - ASSERT_EQ(node->inputs[0], input_id); - ASSERT_EQ(node->inputs[1], kernel_id); - ASSERT_EQ(node->inputs[2], bias_id); - ASSERT_EQ(node->num_outputs, 1); - ASSERT_EQ(node->outputs[0], output_id); - ASSERT_EQ(node->flags, 0); -} - -TEST_P(FullyConnectedTestF16, define) -{ - bool use_bias = GetParam().use_bias; - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - - uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), /*external_id=*/1, - /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - if (use_bias) { - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - } - - uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1095,44 +386,48 @@ TEST_P(FullyConnectedTestF16, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestF32, define) -{ +TEST_F(FullyConnectedTestF32, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - 
xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), /*external_id=*/1, - /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + kernel_dims.size(), kernel_dims.data(), + kernel.data(), /*external_id=*/1, + /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1148,45 +443,51 @@ TEST_F(FullyConnectedTestF32, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestF32QC4W, define) -{ +TEST_F(FullyConnectedTestF32QC4W, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, /*zero_point=*/8, requantization_scales.data(), - kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), kernel.data(), /*external_id=*/1, - /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, /*zero_point=*/8, + requantization_scales.data(), kernel_dims.size(), 
/*channel_dim=*/0, + kernel_dims.data(), kernel.data(), /*external_id=*/1, + /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1202,40 +503,44 @@ TEST_F(FullyConnectedTestF32QC4W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestF32QC4W, define_without_bias) -{ +TEST_F(FullyConnectedTestF32QC4W, define_without_bias) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, /*zero_point=*/8, requantization_scales.data(), - kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), kernel.data(), /*external_id=*/1, - /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, /*zero_point=*/8, + requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, + kernel_dims.data(), kernel.data(), /*external_id=*/1, + /*flags=*/0, &kernel_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected( - subgraph, output_min, output_max, input_id, kernel_id, 
XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected( + subgraph, output_min, output_max, input_id, kernel_id, + XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1251,44 +556,49 @@ TEST_F(FullyConnectedTestF32QC4W, define_without_bias) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestF32QC8W, define) -{ +TEST_F(FullyConnectedTestF32QC8W, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1304,38 +614,42 @@ TEST_F(FullyConnectedTestF32QC8W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestF32QC8W, define_without_bias) -{ +TEST_F(FullyConnectedTestF32QC8W, define_without_bias) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, 
xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected( + subgraph, output_min, output_max, input_id, kernel_id, + XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1351,44 +665,47 @@ TEST_F(FullyConnectedTestF32QC8W, define_without_bias) ASSERT_EQ(node->flags, 0); } -TEST_F(DynamicFullyConnectedTestF32, define) -{ +TEST_F(DynamicFullyConnectedTestF32, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), nullptr, /*external_id=*/1, - /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), nullptr, /*external_id=*/1, + /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, 
xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), nullptr, - /*external_id=*/2, /*flags=*/0, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + bias_dims.size(), bias_dims.data(), nullptr, + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -1404,8 +721,7 @@ TEST_F(DynamicFullyConnectedTestF32, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestQC8, matches_operator_api) -{ +TEST_F(FullyConnectedTestQC8, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -1416,7 +732,8 @@ TEST_F(FullyConnectedTestQC8, matches_operator_api) const int8_t input_zero_point = -1; const float input_scale = scale_dist(rng); xnnpack::Buffer requantization_scales(output_channels, 1.0f); - std::generate(requantization_scales.begin(), requantization_scales.end(), [&]() { return f32dist(rng); }); + std::generate(requantization_scales.begin(), requantization_scales.end(), + [&]() { return f32dist(rng); }); // Compute reference results, without renormalization. initialize_accumulators_from_bias(); @@ -1424,32 +741,43 @@ TEST_F(FullyConnectedTestQC8, matches_operator_api) for (size_t oc = 0; oc < output_channels; oc++) { for (size_t ic = 0; ic < input_channels; ic++) { accumulators[i * output_channels + oc] += - (int32_t(input[i * input_channels + ic]) - int32_t(input_zero_point)) * - int32_t(kernel[oc * input_channels + ic]); + (static_cast(input[i * input_channels + ic]) - + static_cast(input_zero_point)) * + static_cast(kernel[oc * input_channels + ic]); } } } // Compute renormalization parameters. 
- const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); - const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); - - float output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; - int8_t output_zero_point = int8_t(std::max( - std::min( - lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), - long(std::numeric_limits<int8_t>::max())), - long(std::numeric_limits<int8_t>::min()))); - const int8_t quantized_output_min = xnn_qs8_quantize(output_min, output_scale, output_zero_point); - const int8_t quantized_output_max = xnn_qs8_quantize(output_max, output_scale, output_zero_point); + const int32_t accumulated_min = + *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulated_max = + *std::max_element(accumulators.cbegin(), accumulators.cend()); + + float output_scale = static_cast<double>(static_cast<uint32_t>( + accumulated_max - accumulated_min)) / + 255.0; + int8_t output_zero_point = static_cast<int8_t>(std::max( + std::min( + lrint(-0.5 - + 0.5 * static_cast<double>(accumulated_min + accumulated_max) / + output_scale), + static_cast<long>(std::numeric_limits<int8_t>::max())), + static_cast<long>(std::numeric_limits<int8_t>::min()))); + const int8_t quantized_output_min = + xnn_qs8_quantize(output_min, output_scale, output_zero_point); + const int8_t quantized_output_max = + xnn_qs8_quantize(output_max, output_scale, output_zero_point); // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_qs8_qc8w( - input_channels, output_channels, input_channels, output_channels, input_zero_point, input_scale, - requantization_scales.data(), kernel.data(), - bias.data(), output_zero_point, output_scale, quantized_output_min, quantized_output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + input_zero_point, input_scale, requantization_scales.data(), + kernel.data(), bias.data(), output_zero_point, output_scale, + quantized_output_min, quantized_output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1457,61 +785,72 @@ TEST_F(FullyConnectedTestQC8, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qs8_qc8w(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qs8_qc8w(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qs8_qc8w( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qs8_qc8w( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API.
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, input_zero_point, input_scale, input_dims.size(), - input_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, input_zero_point, input_scale, + input_dims.size(), input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, - requantization_scales.data(), kernel_dims.size(), 0, kernel_dims.data(), - kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), 0, kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint32, requantization_scales.data(), bias_dims.size(), 0, bias_dims.data(), - bias.data(), /*external_id=*/2, /*flags=*/0, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint32, requantization_scales.data(), + bias_dims.size(), 0, bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(), - output_dims.data(), nullptr, /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, output_zero_point, output_scale, + output_dims.size(), output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime( + runtime, xnn_delete_runtime); std::array<xnn_external_value, 2> external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; +
ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestQS8, matches_operator_api) -{ +TEST_F(FullyConnectedTestQS8, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -1529,31 +868,43 @@ TEST_F(FullyConnectedTestQS8, matches_operator_api) for (size_t oc = 0; oc < output_channels; oc++) { for (size_t ic = 0; ic < input_channels; ic++) { accumulators[i * output_channels + oc] += - (int32_t(input[i * input_channels + ic]) - int32_t(input_zero_point)) * - int32_t(kernel[oc * input_channels + ic]); + (static_cast<int32_t>(input[i * input_channels + ic]) - + static_cast<int32_t>(input_zero_point)) * + static_cast<int32_t>(kernel[oc * input_channels + ic]); } } } // Compute renormalization parameters. - const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); - const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); - - float output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; - int8_t output_zero_point = int8_t(std::max( - std::min( - lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), - long(std::numeric_limits<int8_t>::max())), - long(std::numeric_limits<int8_t>::min()))); - const int8_t quantized_output_min = xnn_qs8_quantize(output_min, output_scale, output_zero_point); - const int8_t quantized_output_max = xnn_qs8_quantize(output_max, output_scale, output_zero_point); + const int32_t accumulated_min = + *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulated_max = + *std::max_element(accumulators.cbegin(), accumulators.cend()); + + float output_scale = static_cast<double>(static_cast<uint32_t>( + accumulated_max - accumulated_min)) / + 255.0; + int8_t output_zero_point = static_cast<int8_t>(std::max( + std::min( + lrint(-0.5 - + 0.5 * static_cast<double>(accumulated_min + accumulated_max) / + output_scale), + static_cast<long>(std::numeric_limits<int8_t>::max())), + static_cast<long>(std::numeric_limits<int8_t>::min()))); + const int8_t quantized_output_min = + xnn_qs8_quantize(output_min, output_scale, output_zero_point); + const int8_t quantized_output_max = + xnn_qs8_quantize(output_max, output_scale, output_zero_point); // Call operator API.
const xnn_status status = xnn_create_fully_connected_nc_qs8( - input_channels, output_channels, input_channels, output_channels, input_zero_point, input_scale, kernel_scale, - kernel.data(), bias.data(), output_zero_point, output_scale, quantized_output_min, quantized_output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + input_zero_point, input_scale, kernel_scale, kernel.data(), bias.data(), + output_zero_point, output_scale, quantized_output_min, + quantized_output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1561,60 +912,72 @@ TEST_F(FullyConnectedTestQS8, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qs8(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qs8(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qs8( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qs8( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, input_zero_point, input_scale, input_dims.size(), - input_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, input_zero_point, input_scale, + input_dims.size(), input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, 0, kernel_scale, kernel_dims.size(), kernel_dims.data(), - kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, 0, kernel_scale, + kernel_dims.size(), kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint32, 0, kernel_scale, bias_dims.size(), bias_dims.data(), - bias.data(), /*external_id=*/2, /*flags=*/0, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint32, 0, kernel_scale, + bias_dims.size(), bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint8, output_zero_point, output_scale, output_dims.size(), - output_dims.data(), nullptr, /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + 
xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint8, output_zero_point, output_scale, + output_dims.size(), output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestQU8, matches_operator_api) -{ +TEST_F(FullyConnectedTestQU8, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -1633,31 +996,44 @@ TEST_F(FullyConnectedTestQU8, matches_operator_api) for (size_t oc = 0; oc < output_channels; oc++) { for (size_t ic = 0; ic < input_channels; ic++) { accumulators[i * output_channels + oc] += - (int32_t(input[i * input_channels + ic]) - int32_t(input_zero_point)) * - (int32_t(kernel[oc * input_channels + ic]) - int32_t(kernel_zero_point)); + (static_cast(input[i * input_channels + ic]) - + static_cast(input_zero_point)) * + (static_cast(kernel[oc * input_channels + ic]) - + static_cast(kernel_zero_point)); } } } // Compute renormalization parameters. 
- const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); - const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); - - const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; - const uint8_t output_zero_point = uint8_t(std::max( - std::min( - lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), - long(std::numeric_limits<uint8_t>::max())), - long(std::numeric_limits<uint8_t>::min()))); - const uint8_t quantized_output_min = xnn_qu8_quantize(output_min, output_scale, output_zero_point); - const uint8_t quantized_output_max = xnn_qu8_quantize(output_max, output_scale, output_zero_point); + const int32_t accumulated_min = + *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulated_max = + *std::max_element(accumulators.cbegin(), accumulators.cend()); + + const double output_scale = static_cast<double>(static_cast<uint32_t>( + accumulated_max - accumulated_min)) / + 255.0; + const uint8_t output_zero_point = static_cast<uint8_t>(std::max( + std::min( + lrint(127.5 - + 0.5 * static_cast<double>(accumulated_min + accumulated_max) / + output_scale), + static_cast<long>(std::numeric_limits<uint8_t>::max())), + static_cast<long>(std::numeric_limits<uint8_t>::min()))); + const uint8_t quantized_output_min = + xnn_qu8_quantize(output_min, output_scale, output_zero_point); + const uint8_t quantized_output_max = + xnn_qu8_quantize(output_max, output_scale, output_zero_point); // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_qu8( - input_channels, output_channels, input_channels, output_channels, input_zero_point, input_scale, kernel_zero_point, - kernel_scale, kernel.data(), bias.data(), output_zero_point, output_scale, quantized_output_min, - quantized_output_max, /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + input_zero_point, input_scale, kernel_zero_point, kernel_scale, + kernel.data(), bias.data(), output_zero_point, output_scale, + quantized_output_min, quantized_output_max, /*flags=*/0, nullptr, nullptr, + &op); + std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1665,60 +1041,72 @@ TEST_F(FullyConnectedTestQU8, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qu8(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qu8(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qu8( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_qu8( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API.
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, input_zero_point, input_scale, input_dims.size(), - input_dims.data(), nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_quint8, input_zero_point, input_scale, + input_dims.size(), input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, 0, kernel_scale, kernel_dims.size(), kernel_dims.data(), - kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_quint8, 0, kernel_scale, + kernel_dims.size(), kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_qint32, 0, kernel_scale, bias_dims.size(), bias_dims.data(), - bias.data(), /*external_id=*/2, /*flags=*/0, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_qint32, 0, kernel_scale, + bias_dims.size(), bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_quantized_tensor_value( - subgraph, xnn_datatype_quint8, output_zero_point, output_scale, output_dims.size(), - output_dims.data(), nullptr, /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_quantized_tensor_value( + subgraph, xnn_datatype_quint8, output_zero_point, output_scale, + output_dims.size(), output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, 
xnn_invoke_runtime(runtime)); // Check outputs match. EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_P(FullyConnectedTestF16, matches_operator_api) -{ +TEST_P(FullyConnectedTestF16, matches_operator_api) { bool use_bias = GetParam().use_bias; ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); @@ -1732,9 +1120,11 @@ TEST_P(FullyConnectedTestF16, matches_operator_api) // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f16( - input_channels, output_channels, input_channels, output_channels, kernel.data(), use_bias ? bias.data() : nullptr, output_min, - output_max, XNN_FLAG_FP32_STATIC_WEIGHTS, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel.data(), use_bias ? bias.data() : nullptr, output_min, output_max, + XNN_FLAG_FP32_STATIC_WEIGHTS, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1742,54 +1132,67 @@ TEST_P(FullyConnectedTestF16, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f16(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f16(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f16( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f16( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; if (use_bias) { ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); } uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, 
&output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. @@ -1801,24 +1204,28 @@ INSTANTIATE_TEST_SUITE_P(UseBias, FullyConnectedTestF16, {FullyConnectedTestParam(false), FullyConnectedTestParam(true)})); -TEST_P(FullyConnectedTestF16, matches_operator_api_f16_weights) -{ +TEST_P(FullyConnectedTestF16, matches_operator_api_f16_weights) { bool use_bias = GetParam().use_bias; ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; - std::generate(input.begin(), input.end(), [&]() { return xnn_float16_from_float(f32dist(rng)); }); - std::generate(kernel_fp16.begin(), kernel_fp16.end(), [&]() { return xnn_float16_from_float(f32dist(rng)); }); + std::generate(input.begin(), input.end(), + [&]() { return xnn_float16_from_float(f32dist(rng)); }); + std::generate(kernel_fp16.begin(), kernel_fp16.end(), + [&]() { return xnn_float16_from_float(f32dist(rng)); }); if (use_bias) { - std::generate(bias_fp16.begin(), bias_fp16.end(), [&]() { return xnn_float16_from_float(f32dist(rng)); }); + std::generate(bias_fp16.begin(), bias_fp16.end(), + [&]() { return xnn_float16_from_float(f32dist(rng)); }); } // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f16( - input_channels, output_channels, input_channels, output_channels, kernel_fp16.data(), use_bias ? bias_fp16.data() : nullptr, output_min, - output_max, /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_fp16.data(), use_bias ? 
bias_fp16.data() : nullptr, output_min, + output_max, /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1826,62 +1233,74 @@ TEST_P(FullyConnectedTestF16, matches_operator_api_f16_weights) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f16(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f16(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f16( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f16( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, kernel_dims.size(), kernel_dims.data(), kernel_fp16.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, kernel_dims.size(), + kernel_dims.data(), kernel_fp16.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; if (use_bias) { ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, bias_dims.size(), bias_dims.data(), bias_fp16.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, bias_dims.size(), + bias_dims.data(), bias_fp16.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); } uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); 
ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32, matches_operator_api) -{ +TEST_F(FullyConnectedTestF32, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -1892,10 +1311,11 @@ TEST_F(FullyConnectedTestF32, matches_operator_api) // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f32( - input_channels, output_channels, input_channels, output_channels, kernel.data(), bias.data(), output_min, - output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel.data(), bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1903,81 +1323,92 @@ TEST_F(FullyConnectedTestF32, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32QC4W, matches_operator_api) -{ +TEST_F(FullyConnectedTestF32QC4W, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer<uint8_t>(output_channels * rounded_input_channels); std::generate(kernel.begin(), kernel.end(), [&]() { return i8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); xnnpack::Buffer<float> requantization_scales(output_channels); - std::generate(requantization_scales.begin(), requantization_scales.end(), [&]() { return scale_dist(rng); }); + std::generate(requantization_scales.begin(), requantization_scales.end(), + [&]() { return scale_dist(rng); }); uint8_t kernel_zero_point = 8; // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, requantization_scales.data(), kernel.data(), bias.data(), - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_zero_point, requantization_scales.data(), kernel.data(), + bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -1985,81 +1416,92 @@ TEST_F(FullyConnectedTestF32QC4W, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc4w(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc4w(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc4w( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc4w( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API.
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, + requantization_scales.data(), kernel_dims.size(), + /*channel_dim=*/0, kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32QC4W, matches_operator_api_without_bias) -{ +TEST_F(FullyConnectedTestF32QC4W, matches_operator_api_without_bias) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); std::generate(kernel.begin(), kernel.end(), [&]() { return i8dist(rng); }); xnnpack::Buffer requantization_scales(output_channels); - std::generate(requantization_scales.begin(), requantization_scales.end(), [&]() { return scale_dist(rng); }); + std::generate(requantization_scales.begin(), requantization_scales.end(), + [&]() { return scale_dist(rng); }); uint8_t kernel_zero_point = 8; // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, - kernel_zero_point, requantization_scales.data(), kernel.data(), nullptr, - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_zero_point, requantization_scales.data(), kernel.data(), nullptr, + output_min, output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2067,55 +1509,66 @@ TEST_F(FullyConnectedTestF32QC4W, matches_operator_api_without_bias) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc4w(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc4w(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc4w( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc4w( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, + requantization_scales.data(), kernel_dims.size(), + /*channel_dim=*/0, kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected( + subgraph, output_min, output_max, input_id, kernel_id, + XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32QC8W, matches_operator_api) -{ +TEST_F(FullyConnectedTestF32QC8W, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -2124,14 +1577,17 @@ TEST_F(FullyConnectedTestF32QC8W, matches_operator_api) std::generate(kernel.begin(), kernel.end(), [&]() { return i8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); xnnpack::Buffer requantization_scales(output_channels); - std::generate(requantization_scales.begin(), requantization_scales.end(), [&]() { return scale_dist(rng); }); + std::generate(requantization_scales.begin(), requantization_scales.end(), + [&]() { return scale_dist(rng); }); // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f32_qc8w( - input_channels, output_channels, input_channels, output_channels, requantization_scales.data(), kernel.data(), bias.data(), output_min, - output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + requantization_scales.data(), kernel.data(), bias.data(), output_min, + output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2139,61 +1595,73 @@ TEST_F(FullyConnectedTestF32QC8W, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc8w(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc8w(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc8w( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc8w( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_without_bias) -{ +TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_without_bias) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -2201,14 +1669,17 @@ TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_without_bias) std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return i8dist(rng); }); xnnpack::Buffer requantization_scales(output_channels); - std::generate(requantization_scales.begin(), requantization_scales.end(), [&]() { return scale_dist(rng); }); + std::generate(requantization_scales.begin(), requantization_scales.end(), + [&]() { return scale_dist(rng); }); // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f32_qc8w( - input_channels, output_channels, input_channels, output_channels, requantization_scales.data(), kernel.data(), - /*bias=*/nullptr, output_min, output_max, - /*flags=*/0, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + requantization_scales.data(), kernel.data(), + /*bias=*/nullptr, output_min, output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2216,48 +1687,59 @@ TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_without_bias) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc8w(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc8w(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc8w( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc8w( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected( - subgraph, output_min, output_max, input_id, kernel_id, XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected( + subgraph, output_min, output_max, input_id, kernel_id, + XNN_INVALID_VALUE_ID, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
@@ -2269,8 +1751,7 @@ INSTANTIATE_TEST_SUITE_P(UseBias, DynamicFullyConnectedTestF32, {FullyConnectedTestParam(false), FullyConnectedTestParam(true)})); -TEST_P(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel) -{ +TEST_P(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel) { bool use_bias = GetParam().use_bias; ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); @@ -2283,8 +1764,10 @@ TEST_P(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel) } // Call operator API. - const xnn_status status = xnn_create_dynamic_fully_connected_nc_f32(output_min, output_max, /*flags=*/0, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + const xnn_status status = xnn_create_dynamic_fully_connected_nc_f32( + output_min, output_max, /*flags=*/0, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2295,74 +1778,85 @@ TEST_P(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel) size_t workspace_size = 0; size_t workspace_alignment = 0; - ASSERT_EQ( - xnn_status_success, xnn_reshape_dynamic_fully_connected_nc_f32( - op, batch_size, input_channels, output_channels, input_channels, output_channels, - &workspace_size, &workspace_alignment, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_dynamic_fully_connected_nc_f32( + op, batch_size, input_channels, output_channels, input_channels, + output_channels, &workspace_size, &workspace_alignment, + /*threadpool=*/nullptr)); ASSERT_NE(workspace_size, 0); ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT); xnnpack::Buffer workspace(workspace_size); - ASSERT_EQ( - xnn_status_success, xnn_setup_dynamic_fully_connected_nc_f32( - op, workspace.data(), input.data(), kernel.data(), use_bias ? bias.data() : nullptr, operator_output.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_dynamic_fully_connected_nc_f32( + op, workspace.data(), input.data(), kernel.data(), + use_bias ? bias.data() : nullptr, operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), nullptr, - /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), nullptr, + /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; if (use_bias) { ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); } uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, - xnn_external_value{kernel_id, kernel.data()}, - xnn_external_value{output_id, subgraph_output.data()}, + xnn_external_value{input_id, input.data()}, + xnn_external_value{kernel_id, kernel.data()}, + xnn_external_value{output_id, subgraph_output.data()}, }; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
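The auto_op/auto_subgraph/auto_runtime guards that recur throughout these tests pair each XNNPACK handle with its delete function, so cleanup still runs when an ASSERT bails out of the test early. A self-contained sketch of the pattern, with the deleter types spelled out as assumed here:

#include <memory>
#include "xnnpack.h"

void sketch_subgraph_guard() {
  xnn_subgraph_t subgraph = nullptr;
  if (xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph) !=
      xnn_status_success) {
    return;
  }
  // The guard owns the handle; xnn_delete_subgraph runs on scope exit.
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(
      subgraph, xnn_delete_subgraph);
  // ... define tensors and nodes here ...
}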
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_bias) -{ +TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_bias) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -2372,8 +1866,10 @@ TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_bias) std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); // Call operator API. - const xnn_status status = xnn_create_dynamic_fully_connected_nc_f32(output_min, output_max, /*flags=*/0, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + const xnn_status status = xnn_create_dynamic_fully_connected_nc_f32( + output_min, output_max, /*flags=*/0, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2384,72 +1880,84 @@ TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_bias) size_t workspace_size = 0; size_t workspace_alignment = 0; - ASSERT_EQ( - xnn_status_success, xnn_reshape_dynamic_fully_connected_nc_f32( - op, batch_size, input_channels, output_channels, input_channels, output_channels, - &workspace_size, &workspace_alignment, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_dynamic_fully_connected_nc_f32( + op, batch_size, input_channels, output_channels, input_channels, + output_channels, &workspace_size, &workspace_alignment, + /*threadpool=*/nullptr)); ASSERT_NE(workspace_size, 0); ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT); xnnpack::Buffer workspace(workspace_size); - ASSERT_EQ( - xnn_status_success, xnn_setup_dynamic_fully_connected_nc_f32( - op, workspace.data(), input.data(), kernel.data(), bias.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_dynamic_fully_connected_nc_f32( + op, workspace.data(), input.data(), kernel.data(), bias.data(), + operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), nullptr, - /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_INPUT, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + bias_dims.size(), bias_dims.data(), nullptr, + /*external_id=*/2, + XNN_VALUE_FLAG_EXTERNAL_INPUT, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, - xnn_external_value{bias_id, bias.data()}, - xnn_external_value{output_id, subgraph_output.data()}, + xnn_external_value{input_id, input.data()}, + xnn_external_value{bias_id, bias.data()}, + xnn_external_value{output_id, subgraph_output.data()}, }; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel_and_bias) -{ +TEST_F(DynamicFullyConnectedTestF32, + matches_operator_api_dynamic_kernel_and_bias) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -2459,8 +1967,10 @@ TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel_and_bia std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); // Call operator API. - const xnn_status status = xnn_create_dynamic_fully_connected_nc_f32(output_min, output_max, /*flags=*/0, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + const xnn_status status = xnn_create_dynamic_fully_connected_nc_f32( + output_min, output_max, /*flags=*/0, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2471,73 +1981,84 @@ TEST_F(DynamicFullyConnectedTestF32, matches_operator_api_dynamic_kernel_and_bia size_t workspace_size = 0; size_t workspace_alignment = 0; - ASSERT_EQ( - xnn_status_success, xnn_reshape_dynamic_fully_connected_nc_f32( - op, batch_size, input_channels, output_channels, input_channels, output_channels, - &workspace_size, &workspace_alignment, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_dynamic_fully_connected_nc_f32( + op, batch_size, input_channels, output_channels, input_channels, + output_channels, &workspace_size, &workspace_alignment, + /*threadpool=*/nullptr)); ASSERT_NE(workspace_size, 0); ASSERT_LE(workspace_alignment, XNN_ALLOCATION_ALIGNMENT); xnnpack::Buffer workspace(workspace_size); - ASSERT_EQ( - xnn_status_success, xnn_setup_dynamic_fully_connected_nc_f32( - op, workspace.data(), input.data(), kernel.data(), bias.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_dynamic_fully_connected_nc_f32( + op, workspace.data(), input.data(), kernel.data(), bias.data(), + operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), nullptr, - /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), nullptr, + /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), nullptr, - /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_INPUT, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + bias_dims.size(), bias_dims.data(), nullptr, + /*external_id=*/2, + XNN_VALUE_FLAG_EXTERNAL_INPUT, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, - xnn_external_value{kernel_id, kernel.data()}, - xnn_external_value{bias_id, bias.data()}, - xnn_external_value{output_id, subgraph_output.data()}, + xnn_external_value{input_id, input.data()}, + xnn_external_value{kernel_id, kernel.data()}, + xnn_external_value{bias_id, bias.data()}, + xnn_external_value{output_id, subgraph_output.data()}, }; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_transposed_weights) -{ +TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_transposed_weights) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_operator_t op = nullptr; @@ -2546,13 +2067,16 @@ TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_transposed_weights) std::generate(kernel.begin(), kernel.end(), [&]() { return i8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); xnnpack::Buffer requantization_scales(output_channels); - std::generate(requantization_scales.begin(), requantization_scales.end(), [&]() { return scale_dist(rng); }); + std::generate(requantization_scales.begin(), requantization_scales.end(), + [&]() { return scale_dist(rng); }); // Call operator API. const xnn_status status = xnn_create_fully_connected_nc_f32_qc8w( - input_channels, output_channels, input_channels, output_channels, requantization_scales.data(), kernel.data(), - bias.data(), output_min, output_max, XNN_FLAG_TRANSPOSE_WEIGHTS, nullptr, nullptr, &op); - std::unique_ptr auto_op(op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + requantization_scales.data(), kernel.data(), bias.data(), output_min, + output_max, XNN_FLAG_TRANSPOSE_WEIGHTS, nullptr, nullptr, &op); + std::unique_ptr auto_op( + op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2560,268 +2084,306 @@ TEST_F(FullyConnectedTestF32QC8W, matches_operator_api_transposed_weights) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc8w(op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc8w(op, input.data(), operator_output.data())); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32_qc8w( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32_qc8w( + op, input.data(), operator_output.data())); ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); // Call subgraph API. 
xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims_tranposed.size(), - /*channel_dim=*/1, kernel_dims_tranposed.data(), kernel.data(), - /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims_tranposed.size(), + /*channel_dim=*/1, kernel_dims_tranposed.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected( - subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, XNN_FLAG_TRANSPOSE_WEIGHTS)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + input_id, kernel_id, bias_id, output_id, + XNN_FLAG_TRANSPOSE_WEIGHTS)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); // Check outputs match. 
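The transposed-weights test above and the invalid_channel_dimension tests below hinge on where the per-output-channel scales attach: with the default [output_channels, input_channels] filter layout channel_dim is 0, while under XNN_FLAG_TRANSPOSE_WEIGHTS the filter is stored as [input_channels, output_channels] and channel_dim moves to 1. A fragment sketching the transposed definition, reusing the ids and buffers from the fixture:

// Filter stored as [input_channels, output_channels] for the transposed case.
const std::array<size_t, 2> kernel_dims_transposed = {input_channels,
                                                      output_channels};
uint32_t kernel_id = XNN_INVALID_VALUE_ID;
xnn_define_channelwise_quantized_tensor_value(
    subgraph, xnn_datatype_qcint8, requantization_scales.data(),
    kernel_dims_transposed.size(), /*channel_dim=*/1,
    kernel_dims_transposed.data(), kernel.data(),
    /*external_id=*/1, /*flags=*/0, &kernel_id);
xnn_define_fully_connected(subgraph, output_min, output_max, input_id,
                           kernel_id, bias_id, output_id,
                           XNN_FLAG_TRANSPOSE_WEIGHTS);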
EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -TEST_F(FullyConnectedTestF32QC8W, non_static_kernel_is_invalid_parameter) -{ +TEST_F(FullyConnectedTestF32QC8W, non_static_kernel_is_invalid_parameter) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), nullptr, /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + nullptr, /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_invalid_parameter, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_invalid_parameter, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); } -TEST_F(FullyConnectedTestF32QC8W, non_static_bias_is_invalid_parameter) -{ +TEST_F(FullyConnectedTestF32QC8W, non_static_bias_is_invalid_parameter) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + 
xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), nullptr, - /*external_id=*/2, /*flags=*/0, &bias_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + bias_dims.size(), bias_dims.data(), nullptr, + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_invalid_parameter, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_invalid_parameter, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); } -TEST_F(FullyConnectedTestF32QC8W, invalid_channel_dimension) -{ +TEST_F(FullyConnectedTestF32QC8W, invalid_channel_dimension) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); const size_t channel_dim = 1; xnnpack::Buffer requantization_scales(kernel_dims[channel_dim], 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), channel_dim, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), channel_dim, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - 
xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_invalid_parameter, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_invalid_parameter, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); } -TEST_F(FullyConnectedTestF32QC8W, transposed_weights_invalid_channel_dimension) -{ +TEST_F(FullyConnectedTestF32QC8W, + transposed_weights_invalid_channel_dimension) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); const size_t channel_dim = 0; - xnnpack::Buffer requantization_scales(kernel_dims_tranposed[channel_dim], 1.0f); + xnnpack::Buffer requantization_scales( + kernel_dims_tranposed[channel_dim], 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/0, &input_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims_tranposed.size(), channel_dim, - kernel_dims_tranposed.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims_tranposed.size(), channel_dim, + kernel_dims_tranposed.data(), kernel.data(), /*external_id=*/1, + /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 
output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - ASSERT_EQ( - xnn_status_invalid_parameter, - xnn_define_fully_connected( - subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, XNN_FLAG_TRANSPOSE_WEIGHTS)); + ASSERT_EQ(xnn_status_invalid_parameter, + xnn_define_fully_connected(subgraph, output_min, output_max, + input_id, kernel_id, bias_id, output_id, + XNN_FLAG_TRANSPOSE_WEIGHTS)); } -class FullyConnectedTestQD8F16QC4W : public FullyConnectedTestBase { -}; +class FullyConnectedTestQD8F16QC4W + : public FullyConnectedTestBase {}; -TEST_F(FullyConnectedTestQD8F16QC4W, define) -{ +TEST_F(FullyConnectedTestQD8F16QC4W, define) { xnnpack::Buffer requantization_scales(output_channels, 1.0f); ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); const uint8_t kernel_zero_point = 8; uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, + requantization_scales.data(), kernel_dims.size(), + /*channel_dim=*/0, kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; 
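The internally_allocated_dynamic_quantization_parameters test that follows builds the dynamically quantized path twice, once with operators and once as a subgraph. On the subgraph side the wiring reduces to a convert node writing into a qdint8 value that the fully-connected node then consumes, roughly as below (ids taken from the fixture):

uint32_t dq_quantized_id = XNN_INVALID_VALUE_ID;
xnn_define_dynamically_quantized_tensor_value(
    subgraph, xnn_datatype_qdint8, input_dims.size(),
    /*num_nonbatch_dims=*/1, input_dims.data(), XNN_INVALID_VALUE_ID,
    /*flags=*/0, &dq_quantized_id);
xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id,
                 dq_quantized_id, /*flags=*/0);
xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id,
                           kernel_id, bias_id, output_id, /*flags=*/0);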
@@ -2837,25 +2399,34 @@ TEST_F(FullyConnectedTestQD8F16QC4W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestQD8F16QC4W, internally_allocated_dynamic_quantization_parameters) -{ +TEST_F(FullyConnectedTestQD8F16QC4W, + internally_allocated_dynamic_quantization_parameters) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(xnn_float16)); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); + xnnpack::Buffer convert_input( + batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(xnn_float16)); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); xnnpack::Buffer subgraph_output(batch_size * output_channels); xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); + std::generate(convert_input.begin(), convert_input.end(), + [&]() { return f32dist(rng); }); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -2865,23 +2436,32 @@ TEST_F(FullyConnectedTestQD8F16QC4W, internally_allocated_dynamic_quantization_p xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; xnn_status status = xnn_create_convert_nc_f16_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f16_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f16_qd8( + convert_op, batch_size, input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + 
ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); status = xnn_create_fully_connected_nc_qd8_f16_qc4w( - input_channels, output_channels, input_channels, output_channels, kernel_zero_point, kernel_scale.data(), - kernel.data(), bias.data(), output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(), + output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -2889,62 +2469,82 @@ TEST_F(FullyConnectedTestQD8F16QC4W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f16_qc4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f16_qc4w( + fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qd8_f16_qc4w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f16_qc4w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. 
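On the operator path just shown, the f16-to-qd8 convert produces both the int8 activations and one xnn_quantization_params entry per batch row, and the same params buffer is then handed to the qd8_f16_qc4w fully-connected. A condensed sketch of that hand-off (the stride comments are my reading of the three identical channel arguments):

xnn_reshape_convert_nc_f16_qd8(convert_op, batch_size, input_channels,
                               /*input_stride=*/input_channels,
                               /*output_stride=*/input_channels,
                               /*threadpool=*/nullptr);
xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(),
                             operator_dq_data.data(),
                             quantization_params.data());
xnn_run_operator(convert_op, /*threadpool=*/nullptr);
// The fully-connected setup reads the per-row params written above.
xnn_setup_fully_connected_nc_qd8_f16_qc4w(fc_op, operator_dq_data.data(),
                                          operator_output.data(),
                                          quantization_params.data());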
- ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, kernel_scale.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + dq_quantized_id, kernel_id, bias_id, + output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, 
xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); } -class FullyConnectedTestQD8F16QB4W : public FullyConnectedTestBase { +class FullyConnectedTestQD8F16QB4W + : public FullyConnectedTestBase { }; -TEST_F(FullyConnectedTestQD8F16QB4W, define) -{ +TEST_F(FullyConnectedTestQD8F16QB4W, define) { size_t block_size = 32; input_channels = round_up_po2(input_channels, block_size); @@ -2955,43 +2555,51 @@ TEST_F(FullyConnectedTestQD8F16QB4W, define) xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - // Adjust number of kernel elements for QB4W. input_channels should be padded to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); + // We adjusted input_channels above, reallocate the kernel. 
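Context for the comment just above: input_channels has already been rounded up to the 32-element block size at the top of this test, and any multiple of 32 is even, so the old round_up_po2(input_channels, 2) byte-boundary step could never change the value. A small sketch of that reasoning (the raw channel count and header path are assumptions):

#include <cassert>
#include <cstddef>
#include "xnnpack/math.h"  // round_up_po2, header path assumed

size_t sketch_qb4w_kernel_sizing(size_t output_channels) {
  const size_t block_size = 32;
  const size_t input_channels =
      round_up_po2(/*hypothetical raw value=*/100, block_size);  // -> 128
  // Any multiple of 32 is even, so the former byte-boundary rounding is a
  // no-op here.
  assert(round_up_po2(input_channels, 2) == input_channels);
  return output_channels * input_channels;
}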
+ kernel = xnnpack::Buffer(output_channels * input_channels); const uint8_t kernel_zero_point = 8; xnnpack::Buffer kernel_scale(output_channels * block_size); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_blockwise_quantized_tensor_value( - subgraph, xnn_datatype_qbint4, kernel_zero_point, reinterpret_cast(kernel_scale.data()), kernel_dims.size(), - /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_blockwise_quantized_tensor_value( + subgraph, xnn_datatype_qbint4, kernel_zero_point, + reinterpret_cast(kernel_scale.data()), + kernel_dims.size(), + /*channel_dim=*/0, block_size, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -3007,8 +2615,8 @@ TEST_F(FullyConnectedTestQD8F16QB4W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestQD8F16QB4W, internally_allocated_dynamic_quantization_parameters) -{ +TEST_F(FullyConnectedTestQD8F16QB4W, + internally_allocated_dynamic_quantization_parameters) { size_t block_size = 32; input_channels = round_up_po2(input_channels, block_size); @@ -3017,25 +2625,33 @@ TEST_F(FullyConnectedTestQD8F16QB4W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); - xnnpack::Buffer subgraph_output(batch_size * output_channels); - xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer 
quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + xnnpack::Buffer convert_input( + batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(xnn_float16)); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); + xnnpack::Buffer subgraph_output(batch_size * output_channels); + xnnpack::Buffer operator_output(batch_size * output_channels); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer kernel_scale(output_channels * block_size); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); + std::generate(convert_input.begin(), convert_input.end(), + [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); + // We adjusted input_channels above, reallocate the kernel. + kernel = xnnpack::Buffer(output_channels * input_channels); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -3046,23 +2662,33 @@ TEST_F(FullyConnectedTestQD8F16QB4W, internally_allocated_dynamic_quantization_p xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; xnn_status status = xnn_create_convert_nc_f16_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f16_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f16_qd8( + convert_op, batch_size, input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); status = xnn_create_fully_connected_nc_qd8_f16_qb4w( - input_channels, output_channels, input_channels, output_channels, block_size, kernel_zero_point, reinterpret_cast(kernel_scale.data()), - kernel.data(), bias.data(), output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, 
xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + block_size, kernel_zero_point, + reinterpret_cast(kernel_scale.data()), kernel.data(), + bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -3070,97 +2696,123 @@ TEST_F(FullyConnectedTestQD8F16QB4W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f16_qb4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f16_qb4w( + fc_op, batch_size, /*threadpool=*/nullptr)); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f16_qb4w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_setup_fully_connected_nc_qd8_f16_qb4w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_blockwise_quantized_tensor_value( - subgraph, xnn_datatype_qbint4, kernel_zero_point, reinterpret_cast(kernel_scale.data()), kernel_dims.size(), - /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_blockwise_quantized_tensor_value( + subgraph, xnn_datatype_qbint4, kernel_zero_point, + reinterpret_cast(kernel_scale.data()), + kernel_dims.size(), + /*channel_dim=*/0, block_size, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), 
nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + dq_quantized_id, kernel_id, bias_id, + output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); } -class FullyConnectedTestQD8F16QC8W : public FullyConnectedTestBase { -}; +class FullyConnectedTestQD8F16QC8W + : public FullyConnectedTestBase {}; -TEST_F(FullyConnectedTestQD8F16QC8W, define) -{ +TEST_F(FullyConnectedTestQD8F16QC8W, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, requantization_scales.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + 
kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -3176,25 +2828,34 @@ TEST_F(FullyConnectedTestQD8F16QC8W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestQD8F16QC8W, internally_allocated_dynamic_quantization_parameters) -{ +TEST_F(FullyConnectedTestQD8F16QC8W, + internally_allocated_dynamic_quantization_parameters) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(xnn_float16)); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); + xnnpack::Buffer convert_input( + batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(xnn_float16)); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); xnnpack::Buffer subgraph_output(batch_size * output_channels); xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); + std::generate(convert_input.begin(), 
convert_input.end(), + [&]() { return f32dist(rng); }); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -3203,23 +2864,31 @@ TEST_F(FullyConnectedTestQD8F16QC8W, internally_allocated_dynamic_quantization_p xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; xnn_status status = xnn_create_convert_nc_f16_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f16_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f16_qd8( + convert_op, batch_size, input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f16_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); status = xnn_create_fully_connected_nc_qd8_f16_qc8w( - input_channels, output_channels, input_channels, output_channels, kernel_scale.data(), - kernel.data(), bias.data(), output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_scale.data(), kernel.data(), bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -3227,99 +2896,125 @@ TEST_F(FullyConnectedTestQD8F16QC8W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f16_qc8w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f16_qc8w( + fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qd8_f16_qc8w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f16_qc8w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. 
-  ASSERT_EQ(
-    xnn_status_success, xnn_define_tensor_value(
-                          subgraph, xnn_datatype_fp16, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0,
-                          /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_tensor_value(
+                subgraph, xnn_datatype_fp16, input_dims.size(),
+                input_dims.data(), nullptr, /*external_id=*/0,
+                /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
   ASSERT_NE(input_id, XNN_INVALID_NODE_ID);
   uint32_t dq_quantized_id = XNN_INVALID_NODE_ID;
-  ASSERT_EQ(
-    xnn_status_success, xnn_define_dynamically_quantized_tensor_value(
-                          subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(),
-                          XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_dynamically_quantized_tensor_value(
+                subgraph, xnn_datatype_qdint8, input_dims.size(),
+                /*num_nonbatch_dims=*/1, input_dims.data(),
+                XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id));
   ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID);
   uint32_t kernel_id = XNN_INVALID_VALUE_ID;
-  ASSERT_EQ(
-    xnn_status_success, xnn_define_channelwise_quantized_tensor_value(
-                          subgraph, xnn_datatype_qcint8, kernel_scale.data(), kernel_dims.size(), /*channel_dim=*/0,
-                          kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_channelwise_quantized_tensor_value(
+                subgraph, xnn_datatype_qcint8, kernel_scale.data(),
+                kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(),
+                kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
   uint32_t bias_id = XNN_INVALID_VALUE_ID;
   ASSERT_EQ(
-    xnn_status_success, xnn_define_tensor_value(
-                          subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
-                          /*external_id=*/2, /*flags=*/0, &bias_id));
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(),
+                              bias_dims.data(), bias.data(),
+                              /*external_id=*/2, /*flags=*/0, &bias_id));
   uint32_t output_id = XNN_INVALID_NODE_ID;
-  ASSERT_EQ( xnn_status_success, xnn_define_tensor_value(
-               subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr,
-               /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_tensor_value(
+                subgraph, xnn_datatype_fp16, output_dims.size(),
+                output_dims.data(), nullptr,
+                /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
+                &output_id));
   ASSERT_NE(output_id, XNN_INVALID_NODE_ID);
   xnn_runtime_t runtime = nullptr;
-  ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0));
-  ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id,
-                                                           kernel_id, bias_id, output_id, /*flags=*/0));
-  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr,
+                             input_id, dq_quantized_id, /*flags=*/0));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_fully_connected(subgraph, output_min, output_max,
+                                       dq_quantized_id, kernel_id, bias_id,
+                                       output_id, /*flags=*/0));
+  ASSERT_EQ(xnn_status_success,
+            xnn_create_runtime_v3(subgraph, nullptr, nullptr,
+                                  xnn_test_runtime_flags(), &runtime));
   ASSERT_NE(nullptr, runtime);
-  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);
+  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(
+      runtime, xnn_delete_runtime);
   std::array<xnn_external_value, 2> external = {
-    xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}};
-  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
+      xnn_external_value{input_id, convert_input.data()},
+      xnn_external_value{output_id, subgraph_output.data()}};
+  ASSERT_EQ(xnn_status_success,
+            xnn_setup_runtime(runtime, external.size(), external.data()));
   ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));
-  EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output));
+  for (size_t i = 0; i < operator_output.size(); i++) {
+    ASSERT_EQ(subgraph_output[i], operator_output[i]);
+  }
 }
-class FullyConnectedTestQD8F32QC8W : public FullyConnectedTestBase {
-};
+class FullyConnectedTestQD8F32QC8W
+    : public FullyConnectedTestBase {};
-TEST_F(FullyConnectedTestQD8F32QC8W, define)
-{
+TEST_F(FullyConnectedTestQD8F32QC8W, define) {
   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
   xnnpack::Buffer<float> requantization_scales(output_channels, 1.0f);
   xnn_subgraph_t subgraph = nullptr;
   ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph));
-  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(
+      subgraph, xnn_delete_subgraph);
   uint32_t input_id = XNN_INVALID_VALUE_ID;
-  ASSERT_EQ(
-    xnn_status_success, xnn_define_dynamically_quantized_tensor_value(
-                          subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(),
-                          /*external_id=*/0, /*flags=*/0, &input_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_dynamically_quantized_tensor_value(
+                subgraph, xnn_datatype_qdint8, input_dims.size(),
+                /*num_nonbatch_dims=*/1, input_dims.data(),
+                /*external_id=*/0, /*flags=*/0, &input_id));
   ASSERT_NE(input_id, XNN_INVALID_VALUE_ID);
   uint32_t kernel_id = XNN_INVALID_VALUE_ID;
-  ASSERT_EQ(
-    xnn_status_success, xnn_define_channelwise_quantized_tensor_value(
-                          subgraph, xnn_datatype_qcint8, requantization_scales.data(), kernel_dims.size(), /*channel_dim=*/0,
-                          kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_channelwise_quantized_tensor_value(
+                subgraph, xnn_datatype_qcint8, requantization_scales.data(),
+                kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(),
+                kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
   uint32_t bias_id = XNN_INVALID_VALUE_ID;
   ASSERT_EQ(
-    xnn_status_success, xnn_define_tensor_value(
-                          subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
-                          /*external_id=*/2, /*flags=*/0, &bias_id));
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(),
+                              bias_dims.data(), bias.data(),
+                              /*external_id=*/2, /*flags=*/0, &bias_id));
   uint32_t output_id = XNN_INVALID_VALUE_ID;
   ASSERT_EQ(
-    xnn_status_success, xnn_define_tensor_value(
-                          subgraph, xnn_datatype_fp16, output_dims.size(), output_dims.data(), nullptr,
-                          /*external_id=*/3, /*flags=*/0, &output_id));
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp16, output_dims.size(),
+                              output_dims.data(), nullptr,
+                              /*external_id=*/3, /*flags=*/0, &output_id));
   ASSERT_NE(output_id, XNN_INVALID_VALUE_ID);
   ASSERT_EQ(
-    xnn_status_success,
-    xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0));
+      xnn_status_success,
+      xnn_define_fully_connected(subgraph, output_min, output_max, input_id,
+                                 kernel_id, bias_id, output_id, /*flags=*/0));
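(Editorial aside, not part of this patch: the tests above replace container-equality checks such as EXPECT_THAT(..., ElementsAreArray(...)) with per-element loops, and the F32 variants further below compare with ASSERT_NEAR using a tolerance scaled by the largest output magnitude. A minimal sketch of how that pattern could be factored into one helper is shown here; the helper name ExpectCloseRelative and its default tolerance are assumptions, not code from this change.)

    // Illustrative sketch only. Compares two output buffers element-wise with
    // an absolute tolerance derived from the largest expected magnitude.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <gtest/gtest.h>

    template <typename Buffer>
    void ExpectCloseRelative(const Buffer& expected, const Buffer& actual,
                             float rel_tol = 1e-3f) {
      ASSERT_EQ(expected.size(), actual.size());
      // Scale the tolerance by the largest |expected| value so the check is
      // robust when the subgraph path quantizes slightly differently (e.g. a
      // qd8 -> qp8 rewrite), as noted in the comments of this change.
      float max_abs_val = 1.0f;
      for (size_t i = 0; i < expected.size(); i++) {
        max_abs_val = std::max(max_abs_val, std::abs(expected[i]));
      }
      for (size_t i = 0; i < expected.size(); i++) {
        ASSERT_NEAR(expected[i], actual[i], max_abs_val * rel_tol);
      }
    }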
ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -3335,26 +3030,35 @@ TEST_F(FullyConnectedTestQD8F32QC8W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestQD8F32QC8W, internally_allocated_dynamic_quantization_parameters) -{ +TEST_F(FullyConnectedTestQD8F32QC8W, + internally_allocated_dynamic_quantization_parameters) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); + xnnpack::Buffer convert_input(batch_size * input_channels + + XNN_EXTRA_BYTES / sizeof(float)); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); xnnpack::Buffer subgraph_output(batch_size * output_channels); xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); + std::generate(convert_input.begin(), convert_input.end(), + [&]() { return f32dist(rng); }); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -3363,23 +3067,31 @@ TEST_F(FullyConnectedTestQD8F32QC8W, internally_allocated_dynamic_quantization_p xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; xnn_status status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f32_qd8( + convert_op, batch_size, 
input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); status = xnn_create_fully_connected_nc_qd8_f32_qc8w( - input_channels, output_channels, input_channels, output_channels, kernel_scale.data(), - kernel.data(), bias.data(), output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_scale.data(), kernel.data(), bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -3387,105 +3099,136 @@ TEST_F(FullyConnectedTestQD8F32QC8W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc8w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc8w( + fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qd8_f32_qc8w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qc8w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. 
- ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value( - subgraph, xnn_datatype_qcint8, kernel_scale.data(), kernel_dims.size(), /*channel_dim=*/0, - kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value( + subgraph, xnn_datatype_qcint8, kernel_scale.data(), + kernel_dims.size(), /*channel_dim=*/0, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + dq_quantized_id, kernel_id, bias_id, + output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array 
external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); + // Don't look for exact matches since the result may differ if the subgraph + // gets automatically converted from `qd8` to `qp8`. + float max_abs_val = 1.0f; + for (float val : operator_output) { + max_abs_val = std::max(max_abs_val, std::abs(val)); + } + for (size_t i = 0; i < operator_output.size(); i++) { + ASSERT_NEAR(subgraph_output[i], operator_output[i], max_abs_val * 1e-3); + } } -class FullyConnectedTestQD8F32QC4W : public FullyConnectedTestBase { -}; +class FullyConnectedTestQD8F32QC4W + : public FullyConnectedTestBase {}; -TEST_F(FullyConnectedTestQD8F32QC4W, define) -{ +TEST_F(FullyConnectedTestQD8F32QC4W, define) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnnpack::Buffer requantization_scales(output_channels, 1.0f); xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. 
-  const size_t rounded_input_channels = round_up_po2(input_channels, 2);
-  kernel = xnnpack::Buffer(output_channels * rounded_input_channels);
   const uint8_t kernel_zero_point = 8;
   xnnpack::Buffer<float> kernel_scale(output_channels);
-  std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); });
+  std::generate(kernel_scale.begin(), kernel_scale.end(),
+                [&]() { return scale_dist(rng); });
   uint32_t kernel_id = XNN_INVALID_VALUE_ID;
-  ASSERT_EQ(
-    xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2(
-                          subgraph, xnn_datatype_qcint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(),
-                          /*channel_dim=*/0, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id));
+  ASSERT_EQ(xnn_status_success,
+            xnn_define_channelwise_quantized_tensor_value_v2(
+                subgraph, xnn_datatype_qcint4, kernel_zero_point,
+                kernel_scale.data(), kernel_dims.size(),
+                /*channel_dim=*/0, kernel_dims.data(), kernel.data(),
+                /*external_id=*/1, /*flags=*/0, &kernel_id));
   uint32_t bias_id = XNN_INVALID_VALUE_ID;
   ASSERT_EQ(
-    xnn_status_success, xnn_define_tensor_value(
-                          subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(),
-                          /*external_id=*/2, /*flags=*/0, &bias_id));
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(),
+                              bias_dims.data(), bias.data(),
+                              /*external_id=*/2, /*flags=*/0, &bias_id));
   uint32_t output_id = XNN_INVALID_VALUE_ID;
   ASSERT_EQ(
-    xnn_status_success, xnn_define_tensor_value(
-                          subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr,
-                          /*external_id=*/3, /*flags=*/0, &output_id));
+      xnn_status_success,
+      xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(),
+                              output_dims.data(), nullptr,
+                              /*external_id=*/3, /*flags=*/0, &output_id));
   ASSERT_NE(output_id, XNN_INVALID_VALUE_ID);
   ASSERT_EQ(
-    xnn_status_success,
-    xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0));
+      xnn_status_success,
+      xnn_define_fully_connected(subgraph, output_min, output_max, input_id,
+                                 kernel_id, bias_id, output_id, /*flags=*/0));
   ASSERT_EQ(subgraph->num_nodes, 1);
   const struct xnn_node* node = &subgraph->nodes[0];
@@ -3501,29 +3244,34 @@ TEST_F(FullyConnectedTestQD8F32QC4W, define)
   ASSERT_EQ(node->flags, 0);
 }
-TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_parameters)
-{
+TEST_F(FullyConnectedTestQD8F32QC4W,
+       internally_allocated_dynamic_quantization_parameters) {
   ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
   xnn_subgraph_t subgraph = nullptr;
-  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph));
-  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);
+  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4,
+                                                    /*flags=*/0, &subgraph));
+  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(
+      subgraph, xnn_delete_subgraph);
   uint32_t input_id = XNN_INVALID_NODE_ID;
-  xnnpack::Buffer<float> convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float));
-  xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES);
+  xnnpack::Buffer<float> convert_input(batch_size * input_channels +
+                                       XNN_EXTRA_BYTES / sizeof(float));
+  xnnpack::Buffer operator_dq_data(batch_size * input_channels +
+                                   XNN_EXTRA_BYTES);
   xnnpack::Buffer<float> subgraph_output(batch_size * output_channels);
   xnnpack::Buffer<float> operator_output(batch_size * output_channels);
-  xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);
-
-  // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even.
-  const size_t rounded_input_channels = round_up_po2(input_channels, 2);
-  kernel = xnnpack::Buffer(output_channels * rounded_input_channels);
+  xnnpack::Buffer quantization_params(
+      batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);
   xnnpack::Buffer<float> kernel_scale(output_channels);
-  std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); });
+  std::generate(kernel_scale.begin(), kernel_scale.end(),
+                [&]() { return scale_dist(rng); });
   std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
   std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
-  std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); });
-  std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; });
+  std::generate(convert_input.begin(), convert_input.end(),
+                [&]() { return f32dist(rng); });
+  std::generate(quantization_params.begin(), quantization_params.end(), [&]() {
+    return xnn_quantization_params{w8dist(rng), f32dist(rng)};
+  });
   const float output_min = -std::numeric_limits<float>::infinity();
   const float output_max = std::numeric_limits<float>::infinity();
@@ -3534,23 +3282,32 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p
   xnn_operator_t convert_op = nullptr;
   xnn_operator_t fc_op = nullptr;
   xnn_status status = xnn_create_convert_nc_f32_qd8(
-    /*flags=*/0, &convert_op);
-  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(convert_op, xnn_delete_operator);
+      /*flags=*/0, &convert_op);
+  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convert_op(
+      convert_op, xnn_delete_operator);
   if (status == xnn_status_unsupported_hardware) {
     GTEST_SKIP();
   }
   ASSERT_EQ(xnn_status_success, status);
   ASSERT_NE(nullptr, convert_op);
-  ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr));
-  ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(),
-                                                             operator_dq_data.data(), quantization_params.data()));
-  ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr));
+  ASSERT_EQ(xnn_status_success,
+            xnn_reshape_convert_nc_f32_qd8(
+                convert_op, batch_size, input_channels, input_channels,
+                input_channels, /*threadpool=*/nullptr));
+  ASSERT_EQ(xnn_status_success,
+            xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(),
+                                         operator_dq_data.data(),
+                                         quantization_params.data()));
+  ASSERT_EQ(xnn_status_success,
+            xnn_run_operator(convert_op, /*threadpool=*/nullptr));
   status = xnn_create_fully_connected_nc_qd8_f32_qc4w(
-    input_channels, output_channels, input_channels, output_channels, kernel_zero_point, kernel_scale.data(),
-    kernel.data(), bias.data(), output_min, output_max,
-    /*flags=*/0, nullptr, nullptr, &fc_op);
-  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fc_op(fc_op, xnn_delete_operator);
+      input_channels, output_channels, input_channels, output_channels,
+      kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(),
+      output_min, output_max,
+      /*flags=*/0, nullptr, nullptr, &fc_op);
+  std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fc_op(
+      fc_op, xnn_delete_operator);
   if (status == xnn_status_unsupported_hardware) {
     GTEST_SKIP();
@@ -3558,79 +3315,118 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p
   ASSERT_EQ(xnn_status_success, status);
ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc4w( + fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qd8_f32_qc4w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qc4w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, + kernel_scale.data(), kernel_dims.size(), + /*channel_dim=*/0, kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, 
xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + dq_quantized_id, kernel_id, bias_id, + output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - EXPECT_EQ(subgraph_output, operator_output); + float max_abs_val = 0.0f; + for (size_t i = 0; i < operator_output.size(); i++) { + max_abs_val = std::max(max_abs_val, std::abs(operator_output[i])); + } + for (size_t i = 0; i < operator_output.size(); i++) { + ASSERT_NEAR(operator_output[i], subgraph_output[i], max_abs_val * 1e-2); + } } -TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_parameters_with_reshape) -{ +TEST_F(FullyConnectedTestQD8F32QC4W, + internally_allocated_dynamic_quantization_parameters_with_reshape) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + + std::vector subgraph_input_dims(input_dims); + std::vector subgraph_output_dims(output_dims); + subgraph_input_dims[0] += 4; + subgraph_output_dims[0] += 4; + const size_t subgraph_batch_size = + NumElements(subgraph_input_dims) / input_channels; + xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer subgraph_output(batch_size * output_channels); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); + xnnpack::Buffer convert_input(subgraph_batch_size * input_channels + + XNN_EXTRA_BYTES / sizeof(float)); + xnnpack::Buffer subgraph_output(subgraph_batch_size * output_channels); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); - - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. 
- const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); + std::generate(convert_input.begin(), convert_input.end(), + [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -3641,23 +3437,32 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; xnn_status status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f32_qd8( + convert_op, batch_size, input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); status = xnn_create_fully_connected_nc_qd8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, kernel_zero_point, kernel_scale.data(), - kernel.data(), nullptr, output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_zero_point, kernel_scale.data(), kernel.data(), nullptr, + output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -3665,11 +3470,14 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc4w( + fc_op, batch_size, 
/*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qd8_f32_qc4w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qc4w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. // @@ -3682,117 +3490,153 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p // Subgraph.......' | // Input3.................' - std::vector subgraph_input_dims(input_dims); - std::vector subgraph_output_dims(output_dims); - subgraph_input_dims[0] += 3; - subgraph_output_dims[0] += 3; - - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, subgraph_input_dims.size(), subgraph_input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, subgraph_input_dims.size(), + subgraph_input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, subgraph_input_dims.size(), /*num_nonbatch_dims=*/1, subgraph_input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, subgraph_input_dims.size(), + /*num_nonbatch_dims=*/1, subgraph_input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, + kernel_scale.data(), kernel_dims.size(), + /*channel_dim=*/0, kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, subgraph_output_dims.size(), subgraph_output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, subgraph_output_dims.size(), + subgraph_output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, XNN_INVALID_NODE_ID, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + 
ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected( + subgraph, output_min, output_max, dq_quantized_id, kernel_id, + XNN_INVALID_NODE_ID, output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); struct xnn_node* node = &subgraph->nodes[0]; ASSERT_EQ(node->type, xnn_node_type_convert); - const size_t dynamic_param_size = runtime->values[node->outputs[0]].quantization.dynamic_params_size; - ASSERT_GT(dynamic_param_size, 0); // 1st inference: lets start smaller than we planned memory for std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_dims.size(), input_dims.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_reshape_external_value(runtime, input_id, input_dims.size(), + input_dims.data())); ASSERT_EQ(xnn_status_success, xnn_reshape_runtime(runtime)); - ASSERT_EQ(xnn_status_success, xnn_setup_runtime_v2(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime_v2(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - EXPECT_EQ(subgraph_output, operator_output); - const size_t dynamic_param_size1 = runtime->values[node->outputs[0]].quantization.dynamic_params_size; - // No change in dynamic param size after the first inference - ASSERT_EQ(dynamic_param_size, dynamic_param_size1); + float max_abs_val = 0.0f; + for (size_t i = 0; i < operator_output.size(); i++) { + max_abs_val = std::max(max_abs_val, std::abs(operator_output[i])); + } + for (size_t i = 0; i < operator_output.size(); i++) { + ASSERT_NEAR(operator_output[i], subgraph_output[i], max_abs_val * 1e-2); + } - // 2nd inference: The dq-params should be properly allocated to handle a resize without memory retrigger + // 2nd inference: The dq-params should be properly allocated to handle a + // resize without memory retrigger input_dims[0] += 2; - size_t batch_size2 = std::accumulate(input_dims.begin(), input_dims.end() - 1, 1, std::multiplies()); - xnnpack::Buffer convert_input2(batch_size2 * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - std::generate(convert_input2.begin(), convert_input2.end(), [&]() { return f32dist(rng); }); + size_t batch_size2 = std::accumulate(input_dims.begin(), input_dims.end() - 1, + 1, std::multiplies()); + xnnpack::Buffer convert_input2(batch_size2 * input_channels + + XNN_EXTRA_BYTES / sizeof(float)); + std::generate(convert_input2.begin(), convert_input2.end(), + [&]() { return f32dist(rng); }); xnnpack::Buffer subgraph_output2(batch_size2 * output_channels); std::array external2 = { - xnn_external_value{input_id, convert_input2.data()}, xnn_external_value{output_id, subgraph_output2.data()}}; - ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_dims.size(), input_dims.data())); + xnn_external_value{input_id, convert_input2.data()}, + xnn_external_value{output_id, subgraph_output2.data()}}; + 
ASSERT_EQ(xnn_status_success, + xnn_reshape_external_value(runtime, input_id, input_dims.size(), + input_dims.data())); ASSERT_EQ(xnn_status_success, xnn_reshape_runtime(runtime)); - ASSERT_EQ(xnn_status_success, xnn_setup_runtime_v2(runtime, external2.size(), external2.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime_v2(runtime, external2.size(), external2.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - const size_t dynamic_param_size2 = runtime->values[node->outputs[0]].quantization.dynamic_params_size; - // No change after the second inference - ASSERT_EQ(dynamic_param_size1, dynamic_param_size2); - - // 3rd inference: The dq-params should be properly allocated even with memory retrigger - input_dims[0] += 2; // +4 total - size_t batch_size3 = std::accumulate(input_dims.begin(), input_dims.end() - 1, 1, std::multiplies()); - xnnpack::Buffer convert_input3(batch_size3 * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - std::generate(convert_input3.begin(), convert_input3.end(), [&]() { return f32dist(rng); }); + + // 3rd inference: The dq-params should be properly allocated even with memory + // retrigger + input_dims[0] += 2; // +4 total + size_t batch_size3 = std::accumulate(input_dims.begin(), input_dims.end() - 1, + 1, std::multiplies()); + xnnpack::Buffer convert_input3(batch_size3 * input_channels + + XNN_EXTRA_BYTES / sizeof(float)); + std::generate(convert_input3.begin(), convert_input3.end(), + [&]() { return f32dist(rng); }); xnnpack::Buffer subgraph_output3(batch_size3 * output_channels); std::array external3 = { - xnn_external_value{input_id, convert_input3.data()}, xnn_external_value{output_id, subgraph_output3.data()}}; - ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_dims.size(), input_dims.data())); + xnn_external_value{input_id, convert_input3.data()}, + xnn_external_value{output_id, subgraph_output3.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_reshape_external_value(runtime, input_id, input_dims.size(), + input_dims.data())); ASSERT_EQ(xnn_status_success, xnn_reshape_runtime(runtime)); - ASSERT_EQ(xnn_status_success, xnn_setup_runtime_v2(runtime, external3.size(), external3.data())); + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime_v2(runtime, external3.size(), external3.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - const size_t dynamic_param_size3 = runtime->values[node->outputs[0]].quantization.dynamic_params_size; - // It should be larger after the third inference - ASSERT_LT(dynamic_param_size2, dynamic_param_size3); } -TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_parameters_transposed_weights) -{ +// TODO(b/381381604) - Re-enable once fixed. 
+TEST_F( + FullyConnectedTestQD8F32QC4W, + DISABLED_internally_allocated_dynamic_quantization_parameters_transposed_weights) { ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); + xnnpack::Buffer convert_input(batch_size * input_channels + + XNN_EXTRA_BYTES / sizeof(float)); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); xnnpack::Buffer subgraph_output(batch_size * output_channels); xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. + // Adjust number of kernel elements for QC4W. input_channels should be padded + // to byte boundary, hence even. const size_t rounded_output_channels = round_up_po2(output_channels, 2); kernel = xnnpack::Buffer(input_channels * rounded_output_channels); xnnpack::Buffer kernel_scale(output_channels); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); + std::generate(convert_input.begin(), convert_input.end(), + [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -3803,23 +3647,32 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; xnn_status status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, 
/*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f32_qd8( + convert_op, batch_size, input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); status = xnn_create_fully_connected_nc_qd8_f32_qc4w( - input_channels, output_channels, input_channels, output_channels, kernel_zero_point, kernel_scale.data(), - kernel.data(), bias.data(), output_min, output_max, - XNN_FLAG_TRANSPOSE_WEIGHTS, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + input_channels, output_channels, input_channels, output_channels, + kernel_zero_point, kernel_scale.data(), kernel.data(), bias.data(), + output_min, output_max, XNN_FLAG_TRANSPOSE_WEIGHTS, nullptr, nullptr, + &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -3827,350 +3680,87 @@ TEST_F(FullyConnectedTestQD8F32QC4W, internally_allocated_dynamic_quantization_p ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qc4w( + fc_op, batch_size, /*threadpool=*/nullptr)); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qc4w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); - - // Call subgraph API. 
- ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_NODE_ID); - - uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); - ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_channelwise_quantized_tensor_value_v2( - subgraph, xnn_datatype_qcint4, kernel_zero_point, kernel_scale.data(), kernel_dims_tranposed.size(), - /*channel_dim=*/1, kernel_dims_tranposed.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_NODE_ID); - - xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, XNN_FLAG_TRANSPOSE_WEIGHTS)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); - ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); - std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - - EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); -} - -class FullyConnectedTestQD8F32QB4W : public FullyConnectedTestBase { -}; - -TEST_F(FullyConnectedTestQD8F32QB4W, define) -{ - size_t block_size = 32; - input_channels = round_up_po2(input_channels, block_size); - - input_dims[input_dims.size() - 1] = input_channels; - kernel_dims[kernel_dims.size() - 1] = input_channels; - - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - - uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - /*external_id=*/0, /*flags=*/0, &input_id)); - ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - - // Adjust number of kernel elements for QB4W. input_channels should be padded to byte boundary, hence even. 
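// [Editor's note, illustrative only.] The padding below exists because 4-bit
// (qbint4/qcint4) weights pack two values per byte, so the padded dimension
// must be even to land on a byte boundary. A minimal sketch of the rounding,
// assuming round_up_po2(x, 2) behaves like this hypothetical helper:
//
//   size_t round_up_even(size_t x) { return (x + 1) & ~size_t{1}; }
//   // round_up_even(5) == 6: six 4-bit values occupy exactly 3 bytes.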
- const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); - const uint8_t kernel_zero_point = 8; - xnnpack::Buffer kernel_scale(output_channels * block_size); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_blockwise_quantized_tensor_value( - subgraph, xnn_datatype_qbint4, kernel_zero_point, reinterpret_cast(kernel_scale.data()), kernel_dims.size(), - /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - - uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); - - ASSERT_EQ(subgraph->num_nodes, 1); - const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_fully_connected); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); - ASSERT_EQ(node->num_inputs, 3); - ASSERT_EQ(node->inputs[0], input_id); - ASSERT_EQ(node->inputs[1], kernel_id); - ASSERT_EQ(node->inputs[2], bias_id); - ASSERT_EQ(node->num_outputs, 1); - ASSERT_EQ(node->outputs[0], output_id); - ASSERT_EQ(node->flags, 0); -} - -TEST_F(FullyConnectedTestQD8F32QB4W, internally_allocated_dynamic_quantization_parameters) -{ - size_t block_size = 32; - input_channels = round_up_po2(input_channels, block_size); - - input_dims[input_dims.size() - 1] = input_channels; - kernel_dims[kernel_dims.size() - 1] = input_channels; - - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data(batch_size * input_channels + XNN_EXTRA_BYTES); - xnnpack::Buffer subgraph_output(batch_size * output_channels); - xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); - - xnnpack::Buffer kernel_scale(output_channels * block_size); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return scale_dist(rng); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); - - // Adjust number of kernel elements for QC4W. input_channels should be padded to byte boundary, hence even. 
- const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); - - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); - - const uint8_t kernel_zero_point = 8; - - // Call operator API. - xnn_operator_t convert_op = nullptr; - xnn_operator_t fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qd8(convert_op, batch_size, input_channels, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), - operator_dq_data.data(), quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); - - status = xnn_create_fully_connected_nc_qd8_f32_qb4w( - input_channels, output_channels, input_channels, output_channels, block_size, kernel_zero_point, reinterpret_cast(kernel_scale.data()), - kernel.data(), bias.data(), output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); - - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qb4w(fc_op, batch_size, /*threadpool=*/nullptr)); + xnn_setup_fully_connected_nc_qd8_f32_qc4w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qb4w(fc_op, operator_dq_data.data(), operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. 
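// [Editor's note, illustrative only.] These tests follow a parity pattern:
// first run the standalone operator pipeline above (convert f32 -> qd8, then
// the quantized fully-connected operator), then rebuild the same computation
// below as a two-node subgraph (a convert node feeding a fully-connected
// node), and finally compare operator_output against subgraph_output
// element-wise with ASSERT_NEAR using a tolerance proportional to max_abs_val.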
- ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_blockwise_quantized_tensor_value( - subgraph, xnn_datatype_qbint4, kernel_zero_point, reinterpret_cast(kernel_scale.data()), kernel_dims.size(), - /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); - - uint32_t bias_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_NODE_ID); - - xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); - ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); - std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); - ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); -} - -TEST_F(FullyConnectedTestF32, reshape) -{ - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - - std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); - std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - - xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); - - uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, - /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); - 
ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - - uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, - xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, kernel_dims.size(), kernel_dims.data(), kernel.data(), /*external_id=*/1, - /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_channelwise_quantized_tensor_value_v2( + subgraph, xnn_datatype_qcint4, kernel_zero_point, + kernel_scale.data(), kernel_dims_tranposed.size(), + /*channel_dim=*/1, kernel_dims_tranposed.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); - - uint32_t output_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); - ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - - ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); - - ASSERT_EQ(subgraph->num_nodes, 1); - ASSERT_EQ(subgraph->num_values, 4); - const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_fully_connected); - ASSERT_EQ(node->activation.output_min, output_min); - ASSERT_EQ(node->activation.output_max, output_max); - ASSERT_EQ(node->num_inputs, 3); - ASSERT_EQ(node->inputs[0], input_id); - ASSERT_EQ(node->inputs[1], kernel_id); - ASSERT_EQ(node->inputs[2], bias_id); - ASSERT_EQ(node->num_outputs, 1); - ASSERT_EQ(node->outputs[0], output_id); - ASSERT_EQ(node->flags, 0); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + uint32_t output_id = XNN_INVALID_NODE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); + ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + dq_quantized_id, kernel_id, bias_id, + output_id, XNN_FLAG_TRANSPOSE_WEIGHTS)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - 
std::vector new_input_dims(input_dims.begin(), input_dims.end()); - const xnn_shape* output_shape = &runtime->values[node->outputs[0]].shape; - const xnn_shape* kernel_shape = &runtime->values[node->inputs[1]].shape; - - - // case 1 : no change in input dims - ASSERT_EQ(xnn_status_success, node->reshape(&runtime->opdata[0], subgraph->values, subgraph->num_values, nullptr /*threadpool*/)); - - // case 2: Resize input to a smaller size. This should not require memory planning - for (size_t i=0; ireshape(&runtime->opdata[0], runtime->values, runtime->num_values, nullptr /*threadpool*/)); - - // Check that the output shape is correct - for (size_t i=0; inum_dims - 1; ++i) { - ASSERT_EQ(output_shape->dim[i], new_input_dims[i]); - } - ASSERT_EQ(output_shape->dim[output_shape->num_dims-1], kernel_shape->dim[0]); - - // case 3: Resize input to a larger size. This should require memory planning - for (size_t i=0; ireshape(&runtime->opdata[0], runtime->values, runtime->num_values, nullptr /*threadpool*/)); - - // Check that the output shape is correct - for (size_t i=0; inum_dims - 1; ++i) { - ASSERT_EQ(output_shape->dim[i], new_input_dims[i]); + for (size_t i = 0; i < operator_output.size(); i++) { + ASSERT_NEAR(operator_output[i], subgraph_output[i], max_abs_val * 0.0625); } - ASSERT_EQ(output_shape->dim[output_shape->num_dims-1], kernel_shape->dim[0]); - - size_t num_output_elements = std::accumulate(new_input_dims.begin(), new_input_dims.end() - 1, size_t{1}, std::multiplies()) * kernel_shape->dim[0]; - ASSERT_EQ(runtime->values[node->outputs[0]].size, num_output_elements * sizeof(float)); } -TEST_F(FullyConnectedTestQP8F32QB4W, define) -{ - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - if (xnn_init_qp8_f32_qb4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } +class FullyConnectedTestQD8F32QB4W + : public FullyConnectedTestBase {}; + +TEST_F(FullyConnectedTestQD8F32QB4W, define) { size_t block_size = 32; input_channels = round_up_po2(input_channels, block_size); @@ -4181,43 +3771,51 @@ TEST_F(FullyConnectedTestQP8F32QB4W, define) xnn_subgraph_t subgraph = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qpint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - /*external_id=*/0, /*flags=*/0, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + /*external_id=*/0, /*flags=*/0, &input_id)); ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - // Adjust number of kernel elements for QB4W. input_channels should be padded to byte boundary, hence even. - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); + // We adjusted input_channels above, reallocate the kernel. 
+ kernel = xnnpack::Buffer(output_channels * input_channels); const uint8_t kernel_zero_point = 8; - xnnpack::Buffer kernel_scale(output_channels * block_size); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); }); + xnnpack::Buffer kernel_scale(output_channels * block_size); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_blockwise_quantized_tensor_value( - subgraph, xnn_datatype_qbint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_blockwise_quantized_tensor_value( + subgraph, xnn_datatype_qbint4, kernel_zero_point, + reinterpret_cast(kernel_scale.data()), + kernel_dims.size(), + /*channel_dim=*/0, block_size, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/0, &output_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/0, &output_id)); ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); ASSERT_EQ( - xnn_status_success, - xnn_define_fully_connected(subgraph, output_min, output_max, input_id, kernel_id, bias_id, output_id, /*flags=*/0)); + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; @@ -4233,12 +3831,8 @@ TEST_F(FullyConnectedTestQP8F32QB4W, define) ASSERT_EQ(node->flags, 0); } -TEST_F(FullyConnectedTestQP8F32QB4W, matches_operator_api) -{ - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - if (xnn_init_qp8_f32_qb4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } +TEST_F(FullyConnectedTestQD8F32QB4W, + internally_allocated_dynamic_quantization_parameters) { size_t block_size = 32; input_channels = round_up_po2(input_channels, block_size); @@ -4247,25 +3841,33 @@ TEST_F(FullyConnectedTestQP8F32QB4W, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); xnn_subgraph_t subgraph = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph)); - std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(/*external_value_ids=*/4, + /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); uint32_t input_id = XNN_INVALID_NODE_ID; - xnnpack::Buffer convert_input(batch_size * input_channels + XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_dq_data( - xnn_x8_packq_f32qp8_gemm_packed_size(batch_size, 
input_channels)); + xnnpack::Buffer convert_input(batch_size * input_channels + + XNN_EXTRA_BYTES / sizeof(float)); + xnnpack::Buffer operator_dq_data(batch_size * input_channels + + XNN_EXTRA_BYTES); xnnpack::Buffer subgraph_output(batch_size * output_channels); xnnpack::Buffer operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); + xnnpack::Buffer quantization_params( + batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); - xnnpack::Buffer kernel_scale(output_channels * block_size); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); }); + xnnpack::Buffer kernel_scale(output_channels * block_size); + std::generate(kernel_scale.begin(), kernel_scale.end(), + [&]() { return scale_dist(rng); }); std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), [&]() { return f32dist(rng); }); - std::generate(quantization_params.begin(), quantization_params.end(), [&]() { return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); + std::generate(convert_input.begin(), convert_input.end(), + [&]() { return f32dist(rng); }); + std::generate(quantization_params.begin(), quantization_params.end(), [&]() { + return xnn_quantization_params{w8dist(rng), f32dist(rng)}; + }); - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); + // We adjusted input_channels above, reallocate the kernel. + kernel = xnnpack::Buffer(output_channels * input_channels); const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); @@ -4275,24 +3877,34 @@ TEST_F(FullyConnectedTestQP8F32QB4W, matches_operator_api) // Call operator API. 
xnn_operator_t convert_op = nullptr; xnn_operator_t fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qp8( - /*flags=*/0, &convert_op); - std::unique_ptr auto_convert_op(convert_op, xnn_delete_operator); + xnn_status status = xnn_create_convert_nc_f32_qd8( + /*flags=*/0, &convert_op); + std::unique_ptr auto_convert_op( + convert_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8(convert_op, batch_size, input_channels, input_channels, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, xnn_setup_convert_nc_f32_qp8(convert_op, convert_input.data(), - operator_dq_data.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(convert_op, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_reshape_convert_nc_f32_qd8( + convert_op, batch_size, input_channels, input_channels, + input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_convert_nc_f32_qd8(convert_op, convert_input.data(), + operator_dq_data.data(), + quantization_params.data())); + ASSERT_EQ(xnn_status_success, + xnn_run_operator(convert_op, /*threadpool=*/nullptr)); - status = xnn_create_fully_connected_nc_qp8_f32_qb4w( - input_channels, output_channels, input_channels, output_channels, block_size, kernel_zero_point, kernel_scale.data(), - kernel.data(), bias.data(), output_min, output_max, - /*flags=*/0, nullptr, nullptr, &fc_op); - std::unique_ptr auto_fc_op(fc_op, xnn_delete_operator); + status = xnn_create_fully_connected_nc_qd8_f32_qb4w( + input_channels, output_channels, input_channels, output_channels, + block_size, kernel_zero_point, + reinterpret_cast(kernel_scale.data()), kernel.data(), + bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &fc_op); + std::unique_ptr auto_fc_op( + fc_op, xnn_delete_operator); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); @@ -4300,197 +3912,207 @@ TEST_F(FullyConnectedTestQP8F32QB4W, matches_operator_api) ASSERT_EQ(xnn_status_success, status); ASSERT_NE(nullptr, fc_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qp8_f32_qb4w(fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_qd8_f32_qb4w( + fc_op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, + xnn_setup_fully_connected_nc_qd8_f32_qb4w( + fc_op, operator_dq_data.data(), operator_output.data(), + quantization_params.data())); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qp8_f32_qb4w(fc_op, operator_dq_data.data(), operator_output.data())); - ASSERT_EQ(xnn_status_success, xnn_run_operator(fc_op, /*threadpool=*/nullptr)); + xnn_run_operator(fc_op, /*threadpool=*/nullptr)); // Call subgraph API. 
- ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, input_dims.size(), input_dims.data(), nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, /*external_id=*/0, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); ASSERT_NE(input_id, XNN_INVALID_NODE_ID); uint32_t dq_quantized_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_dynamically_quantized_tensor_value( - subgraph, xnn_datatype_qdint8, input_dims.size(), /*num_nonbatch_dims=*/1, input_dims.data(), - XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_dynamically_quantized_tensor_value( + subgraph, xnn_datatype_qdint8, input_dims.size(), + /*num_nonbatch_dims=*/1, input_dims.data(), + XNN_INVALID_VALUE_ID, /*flags=*/0, &dq_quantized_id)); ASSERT_NE(dq_quantized_id, XNN_INVALID_NODE_ID); uint32_t kernel_id = XNN_INVALID_VALUE_ID; - ASSERT_EQ( - xnn_status_success, xnn_define_blockwise_quantized_tensor_value( - subgraph, xnn_datatype_qbint4, kernel_zero_point, kernel_scale.data(), kernel_dims.size(), - /*channel_dim=*/0, block_size, kernel_dims.data(), kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_blockwise_quantized_tensor_value( + subgraph, xnn_datatype_qbint4, kernel_zero_point, + reinterpret_cast(kernel_scale.data()), + kernel_dims.size(), + /*channel_dim=*/0, block_size, kernel_dims.data(), + kernel.data(), /*external_id=*/1, /*flags=*/0, &kernel_id)); uint32_t bias_id = XNN_INVALID_VALUE_ID; ASSERT_EQ( - xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, bias_dims.size(), bias_dims.data(), bias.data(), - /*external_id=*/2, /*flags=*/0, &bias_id)); + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); uint32_t output_id = XNN_INVALID_NODE_ID; - ASSERT_EQ( xnn_status_success, xnn_define_tensor_value( - subgraph, xnn_datatype_fp32, output_dims.size(), output_dims.data(), nullptr, - /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); ASSERT_NE(output_id, XNN_INVALID_NODE_ID); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM)); - ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, - kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + input_id, dq_quantized_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, + dq_quantized_id, kernel_id, bias_id, + output_id, /*flags=*/0)); + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); - std::unique_ptr 
auto_runtime(runtime, xnn_delete_runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); std::array external = { - xnn_external_value{input_id, convert_input.data()}, xnn_external_value{output_id, subgraph_output.data()}}; - ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data())); + xnn_external_value{input_id, convert_input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, + xnn_setup_runtime(runtime, external.size(), external.data())); ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - EXPECT_EQ(subgraph_output, operator_output); } -TEST_F(FullyConnectedTestQP8F32QB4W, matches_qd8_f32_qb4w) -{ - const size_t block_size = 32; - input_channels = round_up_po2(input_channels, block_size); +TEST_F(FullyConnectedTestF32, reshape) { + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); - input_dims[input_dims.size() - 1] = input_channels; - kernel_dims[kernel_dims.size() - 1] = input_channels; + std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); + std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); + std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + xnn_subgraph_t subgraph = nullptr; + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); + std::unique_ptr auto_subgraph( + subgraph, xnn_delete_subgraph); - if (xnn_init_qp8_f32_qb4w_gemm_config() == nullptr) { - GTEST_SKIP(); - } + uint32_t input_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, + &input_id)); + ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); - xnnpack::Buffer convert_input(batch_size * input_channels + - XNN_EXTRA_BYTES / sizeof(float)); - xnnpack::Buffer operator_qp8_data( - xnn_x8_packq_f32qp8_gemm_packed_size(batch_size, input_channels) + - XNN_EXTRA_BYTES); - xnnpack::Buffer operator_qd8_data(batch_size * input_channels + - XNN_EXTRA_BYTES); - xnnpack::Buffer qp8_operator_output(batch_size * output_channels); - xnnpack::Buffer qd8_operator_output(batch_size * output_channels); - xnnpack::Buffer quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS); - - const size_t rounded_input_channels = round_up_po2(input_channels, 2); - const size_t num_blocks = rounded_input_channels / block_size; - xnnpack::Buffer kernel_scale(output_channels * num_blocks); - std::generate(kernel_scale.begin(), kernel_scale.end(), [&]() { return math_cvt_bf16_fp32(scale_dist(rng)); }); - std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); - std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); - std::generate(convert_input.begin(), convert_input.end(), - [&]() { return f32dist(rng); }); + uint32_t kernel_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, + kernel_dims.size(), kernel_dims.data(), + kernel.data(), /*external_id=*/1, + /*flags=*/0, &kernel_id)); - kernel = xnnpack::Buffer(output_channels * rounded_input_channels); + uint32_t bias_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + uint32_t 
output_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, + &output_id)); + ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); - const float output_min = -std::numeric_limits::infinity(); - const float output_max = std::numeric_limits::infinity(); + ASSERT_EQ( + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); - const uint8_t kernel_zero_point = 8; + ASSERT_EQ(subgraph->num_nodes, 1); + ASSERT_EQ(subgraph->num_values, 4); + const struct xnn_node* node = &subgraph->nodes[0]; + ASSERT_EQ(node->type, xnn_node_type_fully_connected); + ASSERT_EQ(node->activation.output_min, output_min); + ASSERT_EQ(node->activation.output_max, output_max); + ASSERT_EQ(node->num_inputs, 3); + ASSERT_EQ(node->inputs[0], input_id); + ASSERT_EQ(node->inputs[1], kernel_id); + ASSERT_EQ(node->inputs[2], bias_id); + ASSERT_EQ(node->num_outputs, 1); + ASSERT_EQ(node->outputs[0], output_id); + ASSERT_EQ(node->flags, 0); - // Call operator API for `qp8`. - xnn_operator_t qp8_convert_op = nullptr; - xnn_operator_t qp8_fc_op = nullptr; - xnn_status status = xnn_create_convert_nc_f32_qp8( - /*flags=*/0, &qp8_convert_op); - std::unique_ptr - auto_qp8_convert_op(qp8_convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qp8_convert_op); - ASSERT_EQ(xnn_status_success, xnn_reshape_convert_nc_f32_qp8( - qp8_convert_op, batch_size, input_channels, - input_channels, /*threadpool=*/nullptr)); + xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qp8(qp8_convert_op, convert_input.data(), - operator_qp8_data.data())); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); + ASSERT_NE(nullptr, runtime); + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); + std::array external = { + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; ASSERT_EQ(xnn_status_success, - xnn_run_operator(qp8_convert_op, /*threadpool=*/nullptr)); + xnn_setup_runtime(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); - status = xnn_create_fully_connected_nc_qp8_f32_qb4w( - input_channels, output_channels, input_channels, output_channels, block_size, - kernel_zero_point, reinterpret_cast(kernel_scale.data()), kernel.data(), bias.data(), - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &qp8_fc_op); - std::unique_ptr auto_qp8_fc_op( - qp8_fc_op, xnn_delete_operator); + std::vector new_input_dims(input_dims.begin(), input_dims.end()); + const xnn_shape* output_shape = &runtime->values[node->outputs[0]].shape; + const xnn_shape* kernel_shape = &runtime->values[node->inputs[1]].shape; - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); - } + // case 1 : no change in input dims + ASSERT_EQ(xnn_status_success, + node->reshape(&runtime->opdata[0], subgraph->values, + subgraph->num_values, nullptr /*threadpool*/)); - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qp8_fc_op); - ASSERT_EQ(xnn_status_success, - xnn_reshape_fully_connected_nc_qp8_f32_qb4w( - qp8_fc_op, batch_size, /*threadpool=*/nullptr)); - ASSERT_EQ(xnn_status_success, 
xnn_setup_fully_connected_nc_qp8_f32_qb4w( - qp8_fc_op, operator_qp8_data.data(), - qp8_operator_output.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(qp8_fc_op, /*threadpool=*/nullptr)); - - // Call operator API for `qd8`. - xnn_operator_t qd8_convert_op = nullptr; - xnn_operator_t qd8_fc_op = nullptr; - status = xnn_create_convert_nc_f32_qd8( - /*flags=*/0, &qd8_convert_op); - std::unique_ptr - auto_qd8_convert_op(qd8_convert_op, xnn_delete_operator); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); + // case 2: Resize input to a smaller size. This should not require memory + // planning + for (size_t i = 0; i < new_input_dims.size() - 1; ++i) { + new_input_dims[i] = input_dims[i] - 1; } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qd8_convert_op); - ASSERT_EQ(xnn_status_success, - xnn_reshape_convert_nc_f32_qd8( - qd8_convert_op, batch_size, input_channels, input_channels, - input_channels, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, - xnn_setup_convert_nc_f32_qd8(qd8_convert_op, convert_input.data(), - operator_qd8_data.data(), - quantization_params.data())); + xnn_reshape_external_value(runtime, input_id, new_input_dims.size(), + new_input_dims.data())); + ASSERT_EQ(xnn_status_success, - xnn_run_operator(qd8_convert_op, /*threadpool=*/nullptr)); + node->reshape(&runtime->opdata[0], runtime->values, + runtime->num_values, nullptr /*threadpool*/)); - status = xnn_create_fully_connected_nc_qd8_f32_qb4w( - input_channels, output_channels, input_channels, output_channels, block_size, - kernel_zero_point, reinterpret_cast(kernel_scale.data()), kernel.data(), bias.data(), - output_min, output_max, - /*flags=*/0, nullptr, nullptr, &qd8_fc_op); - std::unique_ptr auto_qd8_fc_op( - qd8_fc_op, xnn_delete_operator); + // Check that the output shape is correct + for (size_t i = 0; i < output_shape->num_dims - 1; ++i) { + ASSERT_EQ(output_shape->dim[i], new_input_dims[i]); + } + ASSERT_EQ(output_shape->dim[output_shape->num_dims - 1], + kernel_shape->dim[0]); - if (status == xnn_status_unsupported_hardware) { - GTEST_SKIP(); + // case 3: Resize input to a larger size. This should require memory planning + for (size_t i = 0; i < new_input_dims.size() - 1; ++i) { + new_input_dims[i] = input_dims[i] + 1; } - ASSERT_EQ(xnn_status_success, status); - ASSERT_NE(nullptr, qd8_fc_op); - ASSERT_EQ(xnn_status_success, - xnn_reshape_fully_connected_nc_qd8_f32_qb4w( - qd8_fc_op, batch_size, /*threadpool=*/nullptr)); ASSERT_EQ(xnn_status_success, - xnn_setup_fully_connected_nc_qd8_f32_qb4w( - qd8_fc_op, operator_qd8_data.data(), qd8_operator_output.data(), - quantization_params.data())); - ASSERT_EQ(xnn_status_success, - xnn_run_operator(qd8_fc_op, /*threadpool=*/nullptr)); + xnn_reshape_external_value(runtime, input_id, new_input_dims.size(), + new_input_dims.data())); - // Compare the outputs. Note that the values will not be exactly the same - // since the `qd8` quantization rounds to zero, whereas the `qp8` quantization - // does not. 
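// [Editor's note, illustrative only -- refers to the comparison removed
// here.] Round-toward-zero and round-to-nearest quantization can disagree by
// one step: with a scale of 1.0, an input of 2.7 truncates to 2 but rounds to
// 3, which is why the deleted qp8-vs-qd8 check allowed a relative error
// (max_abs_val * 1e-2) instead of exact equality.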
- float max_abs_val = 0.0f; - for (size_t i = 0; i < batch_size * output_channels; i++) { - max_abs_val = std::max(max_abs_val, std::abs(qd8_operator_output[i])); - } - for (size_t i = 0; i < batch_size * output_channels; i++) { - ASSERT_NEAR(qp8_operator_output[i], qd8_operator_output[i], - max_abs_val * 1e-2); + ASSERT_EQ(xnn_status_reallocation_required, + node->reshape(&runtime->opdata[0], runtime->values, + runtime->num_values, nullptr /*threadpool*/)); + + // Check that the output shape is correct + for (size_t i = 0; i < output_shape->num_dims - 1; ++i) { + ASSERT_EQ(output_shape->dim[i], new_input_dims[i]); } + ASSERT_EQ(output_shape->dim[output_shape->num_dims - 1], + kernel_shape->dim[0]); + + size_t num_output_elements = + std::accumulate(new_input_dims.begin(), new_input_dims.end() - 1, + size_t{1}, std::multiplies()) * + kernel_shape->dim[0]; + ASSERT_EQ(runtime->values[node->outputs[0]].size, + num_output_elements * sizeof(float)); } diff --git a/test/gemm-microkernel-tester.cc b/test/gemm-microkernel-tester.cc index 26e76bd11135..5d7ef7b37571 100644 --- a/test/gemm-microkernel-tester.cc +++ b/test/gemm-microkernel-tester.cc @@ -112,7 +112,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float) * 2)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n(), 0); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(int8_t)); xnnpack::Buffer im2col(mr() * ks()); @@ -169,7 +169,11 @@ void GemmMicrokernelTester::Test( } std::shuffle(im2col.begin(), im2col.end(), rng); const size_t k_stride = round_up_po2(k(), kr() * sr()); - xnnpack::Buffer zero_points(k_stride + XNN_EXTRA_BYTES, quantization_params[0].zero_point); + int32_t zp = quantization_params[0].zero_point; + if (unsigned_inputs()) { + zp += 128; + } + xnnpack::Buffer zero_points(k_stride + XNN_EXTRA_BYTES, zp); const int8_t* zero_sentinel = (const int8_t*) &packing_params; const int8_t* zero_data = zero_points.data(); if (zero_index() != SIZE_MAX) { @@ -219,9 +223,20 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. + for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } igemm(m(), n(), k(), ks() * mr() * sizeof(void*), im2col.data(), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(xnn_float16), cn_stride() * sizeof(xnn_float16), + c.data(), cm_stride() * sizeof(xnn_float16), nr() * sizeof(xnn_float16), a_offset() * sizeof(uint8_t), zero_sentinel, zero_data, ¶ms, quantization_params.data()); @@ -229,9 +244,9 @@ void GemmMicrokernelTester::Test( for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. 
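// [Editor's note, illustrative only.] The tolerance computed on the next line
// is a mixed absolute/relative bound, tol = max(atol, |ref| * rtol): tiny
// reference values are compared against the absolute floor (1e-4 here) while
// large ones may deviate by up to 1%. A standalone sketch (hypothetical
// helper name):
//
//   float fp16_tolerance(float ref) {
//     return std::max(1.0e-4f, std::abs(ref) * 1.0e-2f);
//   }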
const float tolerance = std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << "), optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -261,8 +276,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float) * 2)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); - xnnpack::Buffer acc(m() * n()); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n(), 0); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(int8_t)); xnnpack::Buffer im2col(mr() * ks()); @@ -319,7 +333,11 @@ void GemmMicrokernelTester::Test( } std::shuffle(im2col.begin(), im2col.end(), rng); const size_t k_stride = round_up_po2(k(), kr() * sr()); - xnnpack::Buffer zero_points(k_stride + XNN_EXTRA_BYTES, quantization_params[0].zero_point); + int32_t zp = quantization_params[0].zero_point; + if (unsigned_inputs()) { + zp += 128; + } + xnnpack::Buffer zero_points(k_stride + XNN_EXTRA_BYTES, zp); const int8_t* zero_sentinel = (const int8_t*) &packing_params; const int8_t* zero_data = zero_points.data(); if (zero_index() != SIZE_MAX) { @@ -369,9 +387,20 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. + for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } igemm(m(), n(), k(), ks() * mr() * sizeof(void*), im2col.data(), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), a_offset() * sizeof(uint8_t), zero_sentinel, zero_data, ¶ms, quantization_params.data()); @@ -379,10 +408,9 @@ void GemmMicrokernelTester::Test( for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. 
const float tolerance = std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -405,7 +433,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n()); @@ -448,7 +476,7 @@ void GemmMicrokernelTester::Test( m(), n(), k(), a.data(), a_stride() * sizeof(uint8_t), packed_w.data(), - c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t), + c.data(), cm_stride() * sizeof(uint8_t), nr() * sizeof(uint8_t), &quantization_params); for (size_t m_index = 0; m_index < m(); m_index++) { @@ -460,12 +488,12 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax())); - EXPECT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin())); - EXPECT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j])) + EXPECT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), uint32_t(qmax())); + EXPECT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), uint32_t(qmin())); + EXPECT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), uint32_t(c_ref[i * n() + j])) << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j] << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point); } @@ -490,7 +518,7 @@ void GemmMicrokernelTester::Test( ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t)); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(uint8_t)); @@ -562,7 +590,7 @@ void GemmMicrokernelTester::Test( igemm( m(), n(), k(), ks() * mr() * sizeof(void*), im2col.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t), + c.data(), cm_stride() * 
sizeof(uint8_t), nr() * sizeof(uint8_t), a_offset() * sizeof(uint8_t), zero_pointer, &quantization_params); @@ -575,12 +603,12 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax())); - EXPECT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin())); - EXPECT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j])) + EXPECT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), uint32_t(qmax())); + EXPECT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), uint32_t(qmin())); + EXPECT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), uint32_t(c_ref[i * n() + j])) << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j]) << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point); } @@ -608,7 +636,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer scale(n()); xnnpack::Buffer c_ref(m() * n()); @@ -667,7 +695,7 @@ void GemmMicrokernelTester::Test( m(), n(), k(), a.data(), a_stride() * sizeof(int8_t), packed_data, - c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t), + c.data(), cm_stride() * sizeof(int8_t), nr() * sizeof(int8_t), &minmax_params); for (size_t m_index = 0; m_index < m(); m_index++) { @@ -679,12 +707,12 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80); - EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80); - EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j])) + EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmax()) - 0x80); + EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmin()) - 0x80); + EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(c_ref[i * n() + j])) << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j]) << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point); } @@ -712,7 +740,7 @@ void GemmMicrokernelTester::Test( 
ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t)); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer scale(n()); xnnpack::Buffer c_ref(m() * n()); @@ -798,7 +826,7 @@ void GemmMicrokernelTester::Test( igemm( m(), n(), k(), ks() * mr() * sizeof(void*), im2col.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t), + c.data(), cm_stride() * sizeof(int8_t), nr() * sizeof(int8_t), a_offset() * sizeof(uint8_t), zero_pointer, &minmax_params); @@ -811,12 +839,12 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80); - EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80); - EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j])) + EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmax()) - 0x80); + EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmin()) - 0x80); + EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(c_ref[i * n() + j])) << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j]) << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point); } @@ -847,7 +875,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float) * 2)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n(), 0); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -935,18 +963,29 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. + for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } gemm(m(), n(), k(), a.data(), a_stride() * sizeof(int8_t), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(xnn_float16), cn_stride() * sizeof(xnn_float16), ¶ms, quantization_params.data()); + c.data(), cm_stride() * sizeof(xnn_float16), nr() * sizeof(xnn_float16), ¶ms, quantization_params.data()); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. 
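// [Editor's note, illustrative only -- not part of this change.] The XOR with
// 0x80 below flips the sign bit of each int8 input, which is the same as
// adding 128 modulo 256, i.e. it remaps the signed range [-128, 127] onto the
// unsigned range [0, 255] these kernels expect:
//
//   int8_t  s = -3;
//   uint8_t u = static_cast<uint8_t>(s) ^ 0x80;              // 253 ^ 128 == 125
//   assert(u == static_cast<uint8_t>(static_cast<int32_t>(s) + 128));  // -3 + 128 == 125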
const float tolerance = std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << "), optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -976,7 +1015,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float) * 2)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n(), 0); @@ -1065,19 +1104,30 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. + for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } gemm(m(), n(), k(), a.data(), a_stride() * sizeof(int8_t), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), ¶ms, quantization_params.data()); + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms, quantization_params.data()); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. const float tolerance = std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -1112,7 +1162,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k_bytes + packed_n() * (sizeof(int32_t) + sizeof(float) * 2)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n(), 0.0f); @@ -1201,19 +1251,30 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. 
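// Illustration only, not part of this change: a self-check of the
// zero-point/sign-flip adjustment used by the unsigned_inputs() paths in this
// file. XOR-ing an int8_t with 0x80 yields, as uint8_t, the value x + 128 with
// the same wraparound an explicit addition would have, so shifting the zero
// point by 128 leaves every dequantized value (x - zero_point) unchanged.
#include <cassert>
#include <cstdint>

static void check_unsigned_input_adjustment() {
  const int zero_point = 3;  // arbitrary example zero point
  for (int x = -128; x <= 127; ++x) {
    const int8_t a = static_cast<int8_t>(x);
    const uint8_t a_flipped = static_cast<uint8_t>(a ^ 0x80);
    assert(static_cast<int>(a) - zero_point ==
           static_cast<int>(a_flipped) - (zero_point + 128));
  }
}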
+ for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } gemm(m(), n(), k2, a.data(), a_stride() * sizeof(int8_t), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(xnn_float16), cn_stride() * sizeof(xnn_float16), ¶ms, quantization_params.data()); + c.data(), cm_stride() * sizeof(xnn_float16), nr() * sizeof(xnn_float16), ¶ms, quantization_params.data()); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. const float tolerance = std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -1254,7 +1315,7 @@ void GemmMicrokernelTester::Test( /* scales */ packed_n() * num_blocks * sizeof(xnn_bfloat16) + /* bias */ packed_n() * sizeof(float)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n(), 0); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -1367,15 +1428,15 @@ void GemmMicrokernelTester::Test( gemm(m(), n(), k2, a.data(), a_stride() * sizeof(int8_t), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(xnn_float16), cn_stride() * sizeof(xnn_float16), ¶ms, quantization_params.data()); + c.data(), cm_stride() * sizeof(xnn_float16), nr() * sizeof(xnn_float16), ¶ms, quantization_params.data()); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. 
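// Illustration only, not part of this change: the tolerances in these checks
// combine an absolute floor with a relative term, so near-zero references are
// compared against the floor and larger ones against a fraction of their own
// magnitude. The f16 paths use a 1.0e-4 floor with roughly 1e-2..1e-3 relative
// slack; the f32 paths use tighter constants such as 1.0e-5 and 1.0e-6.
#include <algorithm>
#include <cmath>

static float mixed_tolerance(float reference, float abs_tol, float rel_tol) {
  return std::max(abs_tol, std::abs(reference) * rel_tol);
}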
const float tolerance = std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-3f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k2; } } @@ -1410,7 +1471,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k_bytes + packed_n() * (sizeof(int32_t) + sizeof(float) * 2)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n(), 0); @@ -1499,19 +1560,30 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. + for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } gemm(m(), n(), k2, a.data(), a_stride() * sizeof(int8_t), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), ¶ms, quantization_params.data()); + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms, quantization_params.data()); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. const float tolerance = std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k2; } } @@ -1551,7 +1623,7 @@ void GemmMicrokernelTester::Test( /* scales */ packed_n() * num_blocks * sizeof(float) + /* bias */ packed_n() * sizeof(float)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n(), 0); for (size_t iteration = 0; iteration < 1 /* iterations() */; iteration++) { @@ -1656,19 +1728,30 @@ void GemmMicrokernelTester::Test( } } + if (unsigned_inputs()) { + // Some architectures require that the input be unsigned. + // Adjust the zero point and flip the sign of the input to mimic adding + // 128 to the input with correct overflow behaviour. 
+ for (int i = 0; i < quantization_params.size(); ++i) { + quantization_params[i].zero_point += 128; + } + for (int i = 0; i < a.size(); ++i) { + a[i] ^= 0x80; + } + } gemm(m(), n(), k2, a.data(), a_stride() * sizeof(int8_t), static_cast(packed_w.data()), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), ¶ms, quantization_params.data()); + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms, quantization_params.data()); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { // Extract tolerance into variable to workaround test failures on Linux AArch64. const float tolerance = std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-5f); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], tolerance) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k2; } } @@ -1698,7 +1781,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n(), 0.0f); xnnpack::Buffer kernel_scale(n()); xnnpack::Buffer c((mr() - 1) * cm_stride() + - ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n(), 0); @@ -1806,20 +1889,160 @@ void GemmMicrokernelTester::Test( // AArch64. const float tolerance = std::max(1.1e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); - ASSERT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + ASSERT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] << " (accumulator = " << acc[i * n() + j] << "), optimized = " - << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] + << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k2 - << ", cn_stride = " << cn_stride() + << ", nr = " << nr() << ", cm_stride = " << cm_stride(); } } } } +void GemmMicrokernelTester::Test_QP8F32QC8W( + xnn_qp8_f32_qc8w_gemm_minmax_ukernel_fn gemm, + xnn_init_f32_minmax_params_fn init_minmax_params, + xnn_pack_weights_and_biases_fn pack, + xnn_packed_stride_weights_and_biases_fn packed_stride) { + ASSERT_LE(m(), mr()); + + xnnpack::ReplicableRandomDevice rng; + auto f32rng = std::bind(std::uniform_real_distribution(-1.f, 1.f), + std::ref(rng)); + auto scalerng = std::bind(std::uniform_real_distribution(0.5f, 2.f), + std::ref(rng)); + auto w8rng = std::bind(std::uniform_int_distribution( + 0, std::numeric_limits::max()), + std::ref(rng)); + + xnnpack::Buffer input_f32(m() * k()); + xnnpack::Buffer b(n() * k()); + xnnpack::Buffer bias(n(), 0.0f); + xnnpack::Buffer kernel_scale(n()); + xnnpack::Buffer c((mr() - 1) * cm_stride() + + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); + xnnpack::Buffer acc(m() * n()); + xnnpack::Buffer c_ref(m() * n(), 0); + + // Create a fake `gemm_config` for the packing functions. 
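// Illustration only, not part of this change: the fake gemm_config assembled
// below stores kr() and sr() as base-2 logarithms. For a nonzero 32-bit x,
// 31 - clz(x) is floor(log2(x)), which equals log2(x) exactly when x is a
// power of two -- the assumption made for these micro-kernel parameters.
#include <cassert>
#include <cstdint>

static uint32_t floor_log2_u32(uint32_t x) {
  assert(x != 0);
  uint32_t log2 = 0;
  while (x >>= 1) {
    ++log2;  // same result as 31 - count_leading_zeros(x) for the original x
  }
  return log2;
}
// e.g. floor_log2_u32(4) == 2, the value log2_kr takes when kr() == 4.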
+ struct xnn_gemm_config gemm_config; + gemm_config.mr = static_cast(mr()); + gemm_config.mr_packed = static_cast(mr_packed()); + gemm_config.nr = static_cast(nr()); + gemm_config.log2_kr = static_cast(31 - math_clz_nonzero_u32(kr())); + gemm_config.log2_sr = static_cast(31 - math_clz_nonzero_u32(sr())); + + const size_t packed_w_stride = + packed_stride(&gemm_config, k(), /*k_stride=*/k(), /*extra_bytes=*/0); + const size_t packed_w_size = packed_w_stride * round_up(n(), nr()); + xnnpack::Buffer packed_w(packed_w_size); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input_f32.begin(), input_f32.end(), std::ref(f32rng)); + + // Quantize the left-hand operand. + const size_t input_packed_size = + xnn_x8_packq_f32qp8_packed_size(m(), k(), mr_packed(), kr(), sr()); + xnnpack::Buffer input_qp8(input_packed_size); + xnn_x8_packq_f32qp8_ukernel__scalar_u1(m(), k(), mr_packed(), kr(), sr(), + /*m_idx_start=*/0, input_f32.data(), + /*lhs_stride=*/k() * sizeof(float), + input_qp8.data()); + + std::generate(b.begin(), b.end(), std::ref(w8rng)); + std::generate(bias.begin(), bias.end(), std::ref(f32rng)); + std::generate(kernel_scale.begin(), kernel_scale.end(), std::ref(scalerng)); + std::fill(packed_w.begin(), packed_w.end(), 0); + + // RHS packing. + struct xnn_qs8_qc8w_packing_params params; + params.input_zero_point = 1; + params.scale_multiplier = 1.0f; + pack(/*flags=*/0, &gemm_config, k(), n(), + /*groups=*/1, /*k_stride=*/k(), + /*accumulator_init=*/nullptr, + /*weights=*/b.data(), + /*int_extra_data0_fn=*/nullptr, + /*extra_data0=*/bias.data(), + /*extra_data0_size=*/sizeof(float), + /*init_extra_data1_fn=*/ + nullptr, + /*extra_data1=*/kernel_scale.data(), + /*extra_data1_size=*/sizeof(float), + /*packed_weights_ptr=*/packed_w.data(), ¶ms); + + // Compute 32-bit results and output quantization arguments. + std::fill(c_ref.begin(), c_ref.end(), 0); + for (size_t m_index = 0; m_index < m(); m_index++) { + for (size_t n_index = 0; n_index < n(); n_index++) { + for (size_t k_index = 0; k_index < k(); k_index++) { + const size_t nb_index = (n_index * k() + k_index) / 2; + const int32_t bv = static_cast( + (k_index % 2 == 0) ? (b[nb_index] & UINT8_C(0xF)) + : (b[nb_index] >> 4)); + c_ref[m_index * n() + n_index] += + xnn_x8_packq_f32qp8_get_dequantized(m_index, k_index, + input_qp8.data(), k(), + mr_packed(), kr(), sr()) * + static_cast(bv); + } + c_ref[m_index * n() + n_index] *= kernel_scale[n_index]; + c_ref[m_index * n() + n_index] += bias[n_index]; + } + } + + const float accumulated_min = + *std::min_element(c_ref.cbegin(), c_ref.cend()); + const float accumulated_max = + *std::max_element(c_ref.cbegin(), c_ref.cend()); + const float c_min = + qmin() == std::numeric_limits::min() + ? -std::numeric_limits::infinity() + : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * + static_cast(qmin()); + const float c_max = + qmax() == std::numeric_limits::max() + ? std::numeric_limits::infinity() + : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * + static_cast(255 - qmax()); + + // Prepare parameters. 
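// Illustration only, not part of this change: the clamp bounds computed above
// map the 8-bit [qmin(), qmax()] window linearly onto the observed accumulator
// range, so e.g. qmin() == 64 clips roughly the lowest quarter of that range.
// An unconstrained end (qmin() effectively 0 or qmax() effectively 255)
// becomes +/- infinity.
#include <cstdint>

static float map_quantized_bound(uint8_t q, float acc_min, float acc_max) {
  // Linear map of q in [0, 255] onto [acc_min, acc_max].
  return acc_min + (acc_max - acc_min) / 255.0f * static_cast<float>(q);
}
// c_min above equals map_quantized_bound(qmin(), accumulated_min,
// accumulated_max), and c_max equals the analogous value for qmax().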
+ xnn_f32_minmax_params minmax_params; + init_minmax_params(&minmax_params, c_min, c_max); + + for (size_t m_index = 0; m_index < m(); m_index++) { + for (size_t n_index = 0; n_index < n(); n_index++) { + c_ref[m_index * n() + n_index] = + std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min); + } + } + + gemm(m(), n(), k(), input_qp8.data(), packed_w.data(), c.data(), + cm_stride() * sizeof(float), sizeof(float), &minmax_params); + + for (size_t i = 0; i < m(); i++) { + for (size_t j = 0; j < n(); j++) { + // Extract tolerance into variable to workaround test failures on Linux + // AArch64. + const float tolerance = + std::max(1.1e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); + ASSERT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], + c_ref[i * n() + j], tolerance) + << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] + << " (accumulator = " << acc[i * n() + j] << "), optimized = " + << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k() + << ", nr = " << nr() << ", cm_stride = " << cm_stride(); + } + } + } +} + void GemmMicrokernelTester::Test( xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn gemm, xnn_init_f32_qb4w_minmax_params_fn init_minmax_params, @@ -1847,7 +2070,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n(), 0.0f); xnnpack::Buffer kernel_scale2d(n() * packed_k2 / bl()); xnnpack::Buffer c((mr() - 1) * cm_stride() + - ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n(), 0); @@ -1975,14 +2198,14 @@ void GemmMicrokernelTester::Test( // AArch64. const float tolerance = std::max(1.1e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f); - ASSERT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + ASSERT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], tolerance) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] << " (accumulator = " << acc[i * n() + j] << "), optimized = " - << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] + << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k2 - << ", cn_stride = " << cn_stride() + << ", nr = " << nr() << ", cm_stride = " << cm_stride(); } } @@ -2008,7 +2231,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t)); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n()); @@ -2052,7 +2275,7 @@ void GemmMicrokernelTester::Test( m(), n(), k(), a.data(), a_stride() * sizeof(int8_t), packed_data, - c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t), + c.data(), cm_stride() * sizeof(int8_t), nr() * sizeof(int8_t), &quantization_params); for (size_t m_index = 0; m_index < m(); m_index++) { @@ -2064,12 +2287,12 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80); - EXPECT_GE(int32_t(c[i * cm_stride() + (j / 
nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80); - EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j])) + EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmax()) - 0x80); + EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmin()) - 0x80); + EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(c_ref[i * n() + j])) << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j]) << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point); } @@ -2097,7 +2320,7 @@ void GemmMicrokernelTester::Test( ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t)); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer acc(m() * n()); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(int8_t)); @@ -2169,7 +2392,7 @@ void GemmMicrokernelTester::Test( igemm( m(), n(), k(), ks() * mr() * sizeof(void*), im2col.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t), + c.data(), cm_stride() * sizeof(int8_t), nr() * sizeof(int8_t), a_offset() * sizeof(uint8_t), zero_pointer, &quantization_params); @@ -2182,12 +2405,12 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80); - EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80); - EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j])) + EXPECT_LE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmax()) - 0x80); + EXPECT_GE(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(qmin()) - 0x80); + EXPECT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * nr() + j % nr()]), int32_t(c_ref[i * n() + j])) << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j]) << " (accumulator = " << acc[i * n() + j] - << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " + << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point); } @@ -2212,7 +2435,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n()); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer 
c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2258,14 +2481,14 @@ void GemmMicrokernelTester::Test( gemm_minmax(m(), n(), k() * sizeof(xnn_bfloat16), a.data(), a_stride() * sizeof(xnn_bfloat16), packed_w.data(), - c.data(), cm_stride() * sizeof(xnn_bfloat16), cn_stride() * sizeof(xnn_bfloat16), + c.data(), cm_stride() * sizeof(xnn_bfloat16), nr() * sizeof(xnn_bfloat16), ¶ms); // Validate micro-kernel outputs. for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 3.0e-2f)) << "at " << i << ", " << j << ": Mr x Nr x Kr = " << mr() << " x " << nr() @@ -2292,7 +2515,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n()); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2339,15 +2562,15 @@ void GemmMicrokernelTester::Test( gemm_minmax(m(), n(), k() * sizeof(xnn_float16), a.data(), a_stride() * sizeof(xnn_float16), packed_w.data(), - c.data(), cm_stride() * sizeof(xnn_float16), cn_stride() * sizeof(xnn_float16), + c.data(), cm_stride() * sizeof(xnn_float16), nr() * sizeof(xnn_float16), ¶ms); // Validate micro-kernel outputs. for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f)) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -2369,7 +2592,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( ks() * packed_k() * packed_n() + packed_n()); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(xnn_float16)); xnnpack::Buffer im2col(mr() * ks()); @@ -2451,23 +2674,23 @@ void GemmMicrokernelTester::Test( igemm_minmax( m(), n(), k() * sizeof(xnn_float16), ks() * mr() * sizeof(void*), reinterpret_cast(im2col.data()), packed_w.data(), - c.data(), cm_stride() * sizeof(xnn_float16), cn_stride() * sizeof(xnn_float16), + c.data(), cm_stride() * sizeof(xnn_float16), nr() * sizeof(xnn_float16), a_offset() * sizeof(xnn_float16), zero_pointer, ¶ms); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << i << ": 
reference = " << c_ref[i * n() + j] - << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); - EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f)) + EXPECT_NEAR(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f)) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << (float)c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); } } @@ -2490,7 +2713,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w(packed_n() * packed_k() + packed_n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2535,18 +2758,18 @@ void GemmMicrokernelTester::Test( ppmm_minmax(m(), n(), k() * sizeof(float), a.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms); // Validate micro-kernel outputs. 
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -2569,7 +2792,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w(packed_n() * packed_k() + packed_n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2598,18 +2821,18 @@ void GemmMicrokernelTester::Test( gemm(m(), n(), k() * sizeof(float), a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), nullptr); // Validate micro-kernel outputs. for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -2632,7 +2855,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w(packed_n() * packed_k() + packed_n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2661,22 +2884,22 @@ void GemmMicrokernelTester::Test( gemm_relu(m(), n(), k() * sizeof(float), a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), nullptr); // Validate micro-kernel outputs. 
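// Illustration only, not part of this change: packed_n() and packed_k(), used
// to size the packed_w buffers above, round n up to a multiple of nr and k up
// to a multiple of kr * sr (round_up_po2 assumes that multiple is a power of
// two); see the accessors in gemm-microkernel-tester.h later in this diff.
#include <cstddef>

static size_t round_up(size_t value, size_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}
// A packed f32 weight block then holds round_up(n, nr) * round_up(k, kr * sr)
// weight elements plus round_up(n, nr) bias elements, matching
// packed_n() * packed_k() + packed_n() above.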
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], 0.0f) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -2700,7 +2923,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w(packed_n() * packed_k() + packed_n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2748,26 +2971,26 @@ gemm_minmax(m(), n(), k() * sizeof(float), a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), &params); // Validate micro-kernel outputs.
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -2787,7 +3010,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float)); xnnpack::Buffer b(n() * k()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2828,28 +3051,28 @@ void GemmMicrokernelTester::Test( gemm_minmax(m(), n(), k() * sizeof(float), a.data(), a_stride() * sizeof(float), b.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms); // Validate micro-kernel outputs. 
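// Illustration only, not part of this change: the f32 minmax validation loops
// in this file, including the one that follows, perform the same three checks
// per output element -- at or below c_max, at or above c_min, and within a
// mixed absolute/relative tolerance of the reference.
#include <algorithm>
#include <cmath>

static bool output_element_ok(float got, float ref, float c_min, float c_max) {
  const float tolerance = std::max(1.0e-5f, std::abs(ref) * 1.0e-6f);
  return got <= c_max && got >= c_min && std::abs(got - ref) <= tolerance;
}
// The EXPECT_LE / EXPECT_GE / EXPECT_NEAR triplets additionally log the
// Mr x Nr x Kr and M x N x K configuration of the failing case.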
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } @@ -2879,7 +3102,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer scale(n()); xnnpack::Buffer packed_w( packed_n() * packed_k_bytes + packed_n() * sizeof(float) * 2); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -2949,26 +3172,26 @@ void GemmMicrokernelTester::Test( gemm_minmax(m(), n(), k() * sizeof(float), // Note KC measured in bytes of input a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms); // Validate micro-kernel outputs. 
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5, std::abs(c_ref[i * n() + j]) * 1.0e-6)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -2993,7 +3216,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer scale(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(float) * 2); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -3033,14 +3256,14 @@ void GemmMicrokernelTester::Test( gemm(m(), n(), k() * sizeof(float), // Note KC measured in bytes of input a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), nullptr); // Validate micro-kernel outputs. 
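// Illustration only, not part of this change: the packed-weight buffers sized
// above with packed_n() * packed_k() + packed_n() * sizeof(float) * 2 reserve
// two floats per rounded-up output channel (a scale and a bias) on top of the
// one-byte weights, which is where that trailing term comes from.
#include <cstddef>
#include <cstdint>

static size_t packed_size_with_scale_and_bias(size_t packed_n, size_t packed_k) {
  const size_t weight_bytes = packed_n * packed_k * sizeof(int8_t);
  const size_t per_channel_bytes = packed_n * 2 * sizeof(float);  // scale + bias
  return weight_bytes + per_channel_bytes;
}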
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], 0.1f); } @@ -3050,11 +3273,11 @@ void GemmMicrokernelTester::Test( for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -3079,7 +3302,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer scale(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(float) * 2); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -3118,22 +3341,22 @@ void GemmMicrokernelTester::Test( gemm_relu(m(), n(), k() * sizeof(float), // Note KC measured in bytes of input a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), nullptr); // Validate micro-kernel outputs. 
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], 0.0f) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -3159,7 +3382,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer scale(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(float) * 2); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); for (size_t iteration = 0; iteration < iterations(); iteration++) { @@ -3217,26 +3440,26 @@ void GemmMicrokernelTester::Test( gemm_minmax(m(), n(), k() * sizeof(float), // Note KC measured in bytes of input a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), ¶ms); // Validate micro-kernel outputs. 
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5, std::abs(c_ref[i * n() + j]) * 1.0e-6)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -3260,7 +3483,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k()); // no packed_n() - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer acc(mr() * packed_n()); @@ -3305,27 +3528,27 @@ void GemmMicrokernelTester::Test( gemminc(m(), n(), k() * sizeof(float), a.data(), a_stride() * sizeof(float), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), acc.data(), ¶ms); // Validate micro-kernel outputs. 
for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k(); } } @@ -3346,7 +3569,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( ks() * packed_k() * packed_n() + packed_n()); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(float)); xnnpack::Buffer im2col(mr() * ks()); @@ -3406,18 +3629,18 @@ void GemmMicrokernelTester::Test( igemm( m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*), im2col.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), a_offset() * sizeof(float), zero_pointer, nullptr); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); } } @@ -3438,7 +3661,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( ks() * packed_k() * packed_n() + packed_n()); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % 
nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(float)); xnnpack::Buffer im2col(mr() * ks()); @@ -3498,22 +3721,22 @@ void GemmMicrokernelTester::Test( igemm_relu( m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*), im2col.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), a_offset() * sizeof(float), zero_pointer, nullptr); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], 0.0f) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); } } @@ -3535,7 +3758,7 @@ void GemmMicrokernelTester::Test( xnnpack::Buffer packed_w( ks() * packed_k() * packed_n() + packed_n()); xnnpack::Buffer bias(n()); - xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1); + xnnpack::Buffer c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * nr() + (n() - 1) % nr() + 1); xnnpack::Buffer c_ref(m() * n()); xnnpack::Buffer junk(k() + XNN_EXTRA_BYTES / sizeof(float)); xnnpack::Buffer im2col(mr() * ks()); @@ -3610,26 +3833,26 @@ void GemmMicrokernelTester::Test( igemm_minmax( m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*), im2col.data(), packed_w.data(), - c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float), + c.data(), cm_stride() * sizeof(float), nr() * sizeof(float), a_offset() * sizeof(float), zero_pointer, ¶ms); for (size_t i = 0; i < m(); i++) { for (size_t j = 0; j < n(); j++) { - EXPECT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max) + EXPECT_LE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_max) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); - EXPECT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min) + EXPECT_GE(c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_min) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", 
optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); EXPECT_NEAR( - c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], + c[i * cm_stride() + (j / nr()) * nr() + j % nr()], c_ref[i * n() + j], std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f)) << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j] - << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << ", optimized = " << c[i * cm_stride() + (j / nr()) * nr() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks(); } } diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h index f402a590c266..423ed1887277 100644 --- a/test/gemm-microkernel-tester.h +++ b/test/gemm-microkernel-tester.h @@ -17,7 +17,6 @@ #include #include -#include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/pack.h" @@ -26,98 +25,80 @@ class GemmMicrokernelTester { public: - GemmMicrokernelTester clone() const { - return *this; - } + GemmMicrokernelTester clone() const { return *this; } GemmMicrokernelTester& mr(size_t mr) { this->mr_ = mr; return *this; } - size_t mr() const { - return this->mr_; - } + size_t mr() const { return this->mr_; } GemmMicrokernelTester& nr(size_t nr) { this->nr_ = nr; return *this; } - size_t nr() const { - return this->nr_; - } - + size_t nr() const { return this->nr_; } GemmMicrokernelTester& kr(size_t kr) { this->kr_ = kr; return *this; } - size_t kr() const { - return this->kr_; - } + size_t kr() const { return this->kr_; } GemmMicrokernelTester& sr(size_t sr) { this->sr_ = sr; return *this; } - size_t sr() const { - return this->sr_; - } + size_t sr() const { return this->sr_; } GemmMicrokernelTester& m(size_t m) { this->m_ = m; return *this; } - size_t m() const { - return this->m_; - } + size_t m() const { return this->m_; } GemmMicrokernelTester& n(size_t n) { this->n_ = n; return *this; } - size_t n() const { - return this->n_; - } + size_t n() const { return this->n_; } GemmMicrokernelTester& k(size_t k) { this->k_ = k; return *this; } - size_t k() const { - return this->k_; - } + size_t k() const { return this->k_; } GemmMicrokernelTester& ks(size_t ks) { this->ks_ = ks; return *this; } - size_t ks() const { - return this->ks_; - } + size_t ks() const { return this->ks_; } inline GemmMicrokernelTester& bl(size_t bl) { this->bl_ = bl; return *this; } - inline size_t bl() const { - return this->bl_; - } + inline size_t bl() const { return this->bl_; } - size_t packed_k() const { - return round_up_po2(k(), kr() * sr()); - } + size_t packed_k() const { return round_up_po2(k(), kr() * sr()); } + + size_t packed_n() const { return round_up(n(), nr()); } - size_t packed_n() const { - return round_up(n(), nr()); + bool unsigned_inputs() const { return this->unsigned_inputs_; } + + GemmMicrokernelTester& unsigned_inputs(bool unsigned_inputs) { + this->unsigned_inputs_ = unsigned_inputs; + return *this; } GemmMicrokernelTester& a_stride(size_t a_stride) { @@ -135,16 +116,9 @@ class GemmMicrokernelTester { } size_t cm_stride() const { - return this->cm_stride_ == 0 ? 
cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_; - } - - GemmMicrokernelTester& cn_stride(size_t cn_stride) { - this->cn_stride_ = cn_stride; - return *this; - } - - size_t cn_stride() const { - return this->cn_stride_ == 0 ? nr() : this->cn_stride_; + return this->cm_stride_ == 0 + ? nr() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 + : this->cm_stride_; } GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) { @@ -152,81 +126,63 @@ class GemmMicrokernelTester { return *this; } - uint8_t a_zero_point() const { - return this->a_zero_point_; - } + uint8_t a_zero_point() const { return this->a_zero_point_; } GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) { this->b_zero_point_ = b_zero_point; return *this; } - uint8_t b_zero_point() const { - return this->b_zero_point_; - } + uint8_t b_zero_point() const { return this->b_zero_point_; } GemmMicrokernelTester& qmin(uint8_t qmin) { this->qmin_ = qmin; return *this; } - uint8_t qmin() const { - return this->qmin_; - } + uint8_t qmin() const { return this->qmin_; } GemmMicrokernelTester& qmax(uint8_t qmax) { this->qmax_ = qmax; return *this; } - uint8_t qmax() const { - return this->qmax_; - } + uint8_t qmax() const { return this->qmax_; } GemmMicrokernelTester& a_offset(size_t a_offset) { this->a_offset_ = a_offset; return *this; } - size_t a_offset() const { - return this->a_offset_; - } + size_t a_offset() const { return this->a_offset_; } GemmMicrokernelTester& zero_index(size_t zero_index) { this->zero_index_ = zero_index; return *this; } - size_t zero_index() const { - return this->zero_index_; - } + size_t zero_index() const { return this->zero_index_; } GemmMicrokernelTester& iterations(size_t iterations) { this->iterations_ = iterations; return *this; } - size_t iterations() const { - return this->iterations_; - } + size_t iterations() const { return this->iterations_; } GemmMicrokernelTester& known_nc_mod_nr(bool known_nc_mod_nr) { this->known_nc_mod_nr_ = known_nc_mod_nr; return *this; } - bool known_nc_mod_nr() const { - return known_nc_mod_nr_; - } + bool known_nc_mod_nr() const { return known_nc_mod_nr_; } GemmMicrokernelTester& relu(bool relu) { this->relu_ = relu; return *this; } - bool relu() const { - return relu_; - } + bool relu() const { return relu_; } GemmMicrokernelTester& mr_packed(size_t mr_packed) { this->mr_packed_ = mr_packed; @@ -240,164 +196,129 @@ class GemmMicrokernelTester { return this->mr_packed_; } - size_t nc_mod_nr() const { - return known_nc_mod_nr() ? 
n() % nr() : SIZE_MAX; - } - - void Test( - xnn_qd8_f16_qc8w_igemm_ukernel_fn igemm, - xnn_init_f16_minmax_params_fn init_params, - xnn_pack_qs8_igemm_fn pack) const; - - void Test( - xnn_qd8_f32_qc8w_igemm_ukernel_fn gemm, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_qs8_igemm_fn pack) const; - - void Test( - xnn_qu8_gemm_minmax_ukernel_fn gemm, - xnn_init_qu8_conv_minmax_params_fn init_params, - xnn_pack_qu8_gemm_fn pack, - xnn_qu8_requantize_fn requantize) const; - - void Test( - xnn_qu8_igemm_minmax_ukernel_fn igemm, - xnn_init_qu8_conv_minmax_params_fn init_params, - xnn_pack_qu8_igemm_fn pack, - xnn_qu8_requantize_fn requantize); - - void Test( - xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm, - xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack, - xnn_qs8_requantize_fn requantize) const; - - void Test( - xnn_qs8_qc8w_igemm_minmax_ukernel_fn igemm, - xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, - xnn_pack_qs8_igemm_fn pack, - xnn_qs8_requantize_fn requantize) const; - - void Test( - xnn_qs8_gemm_minmax_ukernel_fn gemm, - xnn_init_qs8_conv_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack, - xnn_qs8_requantize_fn requantize) const; - - void Test( - xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm, - xnn_init_f16_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack) const; - - void Test( - xnn_qd8_f32_qc8w_gemm_ukernel_fn gemm, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_qs8_gemm_fn pack) const; - - void Test( - xnn_qd8_f16_qc4w_gemm_ukernel_fn gemm, - xnn_init_f16_qc4w_minmax_params_fn init_params, - xnn_pack_qs8_qc4w_gemm_fn pack) const; - - void Test( - xnn_qd8_f16_qb4w_gemm_ukernel_fn gemm, - xnn_init_f16_qb4w_minmax_params_fn init_params, - xnn_pack_qs8_qb4w_gemm_fn pack) const; - - void Test( - xnn_qd8_f32_qc4w_gemm_ukernel_fn gemm, - xnn_init_f32_qc4w_minmax_params_fn init_params, - xnn_pack_qs8_qc4w_gemm_fn pack) const; - - void Test( - xnn_qd8_f32_qb4w_gemm_ukernel_fn gemm, - xnn_init_f32_qb4w_minmax_params_fn init_params, - xnn_pack_qs8_qb4w_gemm_fn pack) const; - - void Test( - xnn_qs8_igemm_minmax_ukernel_fn igemm, - xnn_init_qs8_conv_minmax_params_fn init_params, - xnn_pack_qs8_igemm_fn pack, - xnn_qs8_requantize_fn requantize) const; - - void Test( - xnn_bf16_gemm_minmax_ukernel_fn gemm_minmax, - xnn_init_bf16_minmax_params_fn init_params, - xnn_pack_f16_gemm_fn pack) const; - - void Test( - xnn_f16_gemm_minmax_ukernel_fn gemm_minmax, - xnn_init_f16_minmax_params_fn init_params, - xnn_pack_f16_gemm_fn pack) const; - - void Test( - xnn_f16_igemm_minmax_ukernel_fn igemm_minmax, - xnn_init_f16_minmax_params_fn init_params, - xnn_pack_f16_igemm_fn pack) const; - - void Test( - xnn_f32_ppmm_minmax_ukernel_fn ppmm_minmax, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_f32_gemm_fn pack) const; - - void Test( - xnn_f32_gemm_ukernel_fn gemm, - xnn_pack_f32_gemm_fn pack) const; - - void Test( - xnn_f32_gemm_relu_ukernel_fn gemm_relu, - xnn_pack_f32_gemm_fn pack) const; - - void Test( - xnn_f32_gemm_minmax_ukernel_fn gemm_minmax, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_f32_gemm_fn pack) const; - - void Test( - xnn_f32_gemm_goi_minmax_ukernel_fn gemm_minmax, - xnn_init_f32_minmax_params_fn init_params) const; - - void Test( - xnn_f32_qc4w_gemm_minmax_ukernel_fn gemm_minmax, - xnn_init_f32_qc4w_minmax_params_fn init_params, - xnn_pack_f32_qc4w_gemm_fn pack) const; - - void Test( - xnn_f32_qc8w_gemm_ukernel_fn gemm, - xnn_pack_f32_qs8w_gemm_fn pack) const; - - void Test( - 
xnn_f32_qc8w_gemm_relu_ukernel_fn gemm_relu, - xnn_pack_f32_qs8w_gemm_fn pack) const; - - void Test( - xnn_f32_qc8w_gemm_minmax_ukernel_fn gemm_minmax, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_f32_qs8w_gemm_fn pack) const; - - void Test( - xnn_f32_gemminc_minmax_ukernel_fn gemminc, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_f32_gemminc_fn pack) const; - - void Test( - xnn_f32_igemm_ukernel_fn igemm, - xnn_pack_f32_igemm_fn pack) const; - - void Test( - xnn_f32_igemm_relu_ukernel_fn igemm_relu, - xnn_pack_f32_igemm_fn pack) const; - - void Test( - xnn_f32_igemm_minmax_ukernel_fn igemm_minmax, - xnn_init_f32_minmax_params_fn init_params, - xnn_pack_f32_igemm_fn pack) const; + size_t nc_mod_nr() const { return known_nc_mod_nr() ? n() % nr() : SIZE_MAX; } + + void Test(xnn_qd8_f16_qc8w_igemm_ukernel_fn igemm, + xnn_init_f16_minmax_params_fn init_params, + xnn_pack_qs8_igemm_fn pack) const; + + void Test(xnn_qd8_f32_qc8w_igemm_ukernel_fn gemm, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_qs8_igemm_fn pack) const; + + void Test(xnn_qu8_gemm_minmax_ukernel_fn gemm, + xnn_init_qu8_conv_minmax_params_fn init_params, + xnn_pack_qu8_gemm_fn pack, xnn_qu8_requantize_fn requantize) const; + + void Test(xnn_qu8_igemm_minmax_ukernel_fn igemm, + xnn_init_qu8_conv_minmax_params_fn init_params, + xnn_pack_qu8_igemm_fn pack, xnn_qu8_requantize_fn requantize); + + void Test(xnn_qs8_qc8w_gemm_minmax_ukernel_fn gemm, + xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, + xnn_pack_qs8_gemm_fn pack, xnn_qs8_requantize_fn requantize) const; + + void Test(xnn_qs8_qc8w_igemm_minmax_ukernel_fn igemm, + xnn_init_qs8_qc8w_conv_minmax_params_fn init_params, + xnn_pack_qs8_igemm_fn pack, xnn_qs8_requantize_fn requantize) const; + + void Test(xnn_qs8_gemm_minmax_ukernel_fn gemm, + xnn_init_qs8_conv_minmax_params_fn init_params, + xnn_pack_qs8_gemm_fn pack, xnn_qs8_requantize_fn requantize) const; + + void Test(xnn_qd8_f16_qc8w_gemm_ukernel_fn gemm, + xnn_init_f16_minmax_params_fn init_params, + xnn_pack_qs8_gemm_fn pack) const; + + void Test(xnn_qd8_f32_qc8w_gemm_ukernel_fn gemm, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_qs8_gemm_fn pack) const; + + void Test(xnn_qd8_f16_qc4w_gemm_ukernel_fn gemm, + xnn_init_f16_qc4w_minmax_params_fn init_params, + xnn_pack_qs8_qc4w_gemm_fn pack) const; + + void Test(xnn_qd8_f16_qb4w_gemm_ukernel_fn gemm, + xnn_init_f16_qb4w_minmax_params_fn init_params, + xnn_pack_qs8_qb4w_gemm_fn pack) const; + + void Test(xnn_qd8_f32_qc4w_gemm_ukernel_fn gemm, + xnn_init_f32_qc4w_minmax_params_fn init_params, + xnn_pack_qs8_qc4w_gemm_fn pack) const; + + void Test(xnn_qd8_f32_qb4w_gemm_ukernel_fn gemm, + xnn_init_f32_qb4w_minmax_params_fn init_params, + xnn_pack_qs8_qb4w_gemm_fn pack) const; + + void Test(xnn_qs8_igemm_minmax_ukernel_fn igemm, + xnn_init_qs8_conv_minmax_params_fn init_params, + xnn_pack_qs8_igemm_fn pack, xnn_qs8_requantize_fn requantize) const; + + void Test(xnn_bf16_gemm_minmax_ukernel_fn gemm_minmax, + xnn_init_bf16_minmax_params_fn init_params, + xnn_pack_f16_gemm_fn pack) const; + + void Test(xnn_f16_gemm_minmax_ukernel_fn gemm_minmax, + xnn_init_f16_minmax_params_fn init_params, + xnn_pack_f16_gemm_fn pack) const; + + void Test(xnn_f16_igemm_minmax_ukernel_fn igemm_minmax, + xnn_init_f16_minmax_params_fn init_params, + xnn_pack_f16_igemm_fn pack) const; + + void Test(xnn_f32_ppmm_minmax_ukernel_fn ppmm_minmax, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_f32_gemm_fn pack) const; + + void Test(xnn_f32_gemm_ukernel_fn 
gemm, xnn_pack_f32_gemm_fn pack) const; + + void Test(xnn_f32_gemm_relu_ukernel_fn gemm_relu, + xnn_pack_f32_gemm_fn pack) const; + + void Test(xnn_f32_gemm_minmax_ukernel_fn gemm_minmax, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_f32_gemm_fn pack) const; + + void Test(xnn_f32_gemm_goi_minmax_ukernel_fn gemm_minmax, + xnn_init_f32_minmax_params_fn init_params) const; + + void Test(xnn_f32_qc4w_gemm_minmax_ukernel_fn gemm_minmax, + xnn_init_f32_qc4w_minmax_params_fn init_params, + xnn_pack_f32_qc4w_gemm_fn pack) const; + + void Test(xnn_f32_qc8w_gemm_ukernel_fn gemm, + xnn_pack_f32_qs8w_gemm_fn pack) const; + + void Test(xnn_f32_qc8w_gemm_relu_ukernel_fn gemm_relu, + xnn_pack_f32_qs8w_gemm_fn pack) const; + + void Test(xnn_f32_qc8w_gemm_minmax_ukernel_fn gemm_minmax, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_f32_qs8w_gemm_fn pack) const; + + void Test(xnn_f32_gemminc_minmax_ukernel_fn gemminc, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_f32_gemminc_fn pack) const; + + void Test(xnn_f32_igemm_ukernel_fn igemm, xnn_pack_f32_igemm_fn pack) const; + + void Test(xnn_f32_igemm_relu_ukernel_fn igemm_relu, + xnn_pack_f32_igemm_fn pack) const; + + void Test(xnn_f32_igemm_minmax_ukernel_fn igemm_minmax, + xnn_init_f32_minmax_params_fn init_params, + xnn_pack_f32_igemm_fn pack) const; void Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_fn gemm, xnn_init_f32_minmax_params_fn init_minmax_params, xnn_pack_weights_and_biases_fn pack, xnn_packed_stride_weights_and_biases_fn packed_stride); + void Test_QP8F32QC8W(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_fn gemm, + xnn_init_f32_minmax_params_fn init_minmax_params, + xnn_pack_weights_and_biases_fn pack, + xnn_packed_stride_weights_and_biases_fn packed_stride); + void Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_fn gemm, xnn_init_f32_qb4w_minmax_params_fn init_minmax_params, xnn_pack_weights_and_biases_fn pack, @@ -413,9 +334,9 @@ class GemmMicrokernelTester { size_t k_{1}; size_t ks_{1}; size_t bl_{SIZE_MAX}; + bool unsigned_inputs_{false}; size_t a_stride_{0}; size_t cm_stride_{0}; - size_t cn_stride_{0}; uint8_t a_zero_point_{127}; uint8_t b_zero_point_{127}; uint8_t qmin_{0}; @@ -428,14 +349,12 @@ class GemmMicrokernelTester { size_t mr_packed_{0}; }; -enum class LoopStepType { - Linear, - NextPrime -}; +enum class LoopStepType { Linear, NextPrime }; struct LoopParams { LoopParams() = default; - explicit LoopParams(size_t from, size_t to, size_t step, LoopStepType step_type) + explicit LoopParams(size_t from, size_t to, size_t step, + LoopStepType step_type) : is_set(true), from(from), to(to), step(step), step_type(step_type) {} bool is_set = false; size_t from = 1; @@ -450,7 +369,8 @@ struct LoopParams { case LoopStepType::NextPrime: return xnnpack::NextPrime(n + step); default: - std::cerr << "Unknown loop step type " << static_cast(step_type) << std::endl; + std::cerr << "Unknown loop step type " << static_cast(step_type) + << std::endl; std::abort(); } } @@ -466,27 +386,33 @@ struct GemmTestParams { isa_check(isa_check) {} // Setters for the loops over `k`, `m`, and `n`. 
- GemmTestParams& loop_k(size_t from, size_t to, size_t step = 1, LoopStepType step_type = LoopStepType::NextPrime) { + GemmTestParams& loop_k(size_t from, size_t to, size_t step = 1, + LoopStepType step_type = LoopStepType::NextPrime) { loop_k_ = LoopParams(from, to, step, step_type); return *this; } - GemmTestParams& loop_m(size_t from, size_t to, size_t step = 1, LoopStepType step_type = LoopStepType::Linear) { + GemmTestParams& loop_m(size_t from, size_t to, size_t step = 1, + LoopStepType step_type = LoopStepType::Linear) { loop_m_ = LoopParams(from, to, step, step_type); return *this; } - GemmTestParams& loop_n(size_t from, size_t to, size_t step = 1, LoopStepType step_type = LoopStepType::NextPrime) { + GemmTestParams& loop_n(size_t from, size_t to, size_t step = 1, + LoopStepType step_type = LoopStepType::NextPrime) { loop_n_ = LoopParams(from, to, step, step_type); return *this; } - GemmTestParams& loop_zi(size_t from, size_t to, size_t step = 1, LoopStepType step_type = LoopStepType::Linear) { + GemmTestParams& loop_zi(size_t from, size_t to, size_t step = 1, + LoopStepType step_type = LoopStepType::Linear) { loop_zi_ = LoopParams(from, to, step, step_type); return *this; } - GemmTestParams& loop_bzp(size_t from, size_t to, size_t step = 1, LoopStepType step_type = LoopStepType::Linear) { + GemmTestParams& loop_bzp(size_t from, size_t to, size_t step = 1, + LoopStepType step_type = LoopStepType::Linear) { loop_bzp_ = LoopParams(from, to, step, step_type); return *this; } - GemmTestParams& loop_bl(size_t from, size_t to, size_t step = 1, LoopStepType step_type = LoopStepType::Linear) { + GemmTestParams& loop_bl(size_t from, size_t to, size_t step = 1, + LoopStepType step_type = LoopStepType::Linear) { loop_bl_ = LoopParams(from, to, step, step_type); return *this; } diff --git a/test/max-pooling-2d.cc b/test/max-pooling-2d.cc index 65a4ff58ca50..b09fab5c8e44 100644 --- a/test/max-pooling-2d.cc +++ b/test/max-pooling-2d.cc @@ -23,6 +23,7 @@ #include "xnnpack/requantization.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class MaxPooling2DTestBase : public ::testing::Test { protected: @@ -371,7 +372,7 @@ TEST_F(MaxPooling2DTestQS8, matches_operator_api) stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -442,7 +443,7 @@ TEST_F(MaxPooling2DTestQU8, matches_operator_api) stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -508,7 +509,7 @@ TEST_F(MaxPooling2DTestF16, matches_operator_api) stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, 
xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -573,7 +574,7 @@ TEST_F(MaxPooling2DTestF32, matches_operator_api) stride_width, dilation_height, dilation_width, output_min, output_max, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -636,7 +637,7 @@ TEST_F(MaxPooling2DTestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -708,7 +709,7 @@ TEST_F(MaxPooling2DTestF32, ReshapeWithPadding) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -779,7 +780,7 @@ TEST_F(MaxPooling2DTestF32, ReshapeWithPaddingAndDilation) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/memory-planner.cc b/test/memory-planner.cc index 7b1d707d1f6a..d6343049b7b7 100644 --- a/test/memory-planner.cc +++ b/test/memory-planner.cc @@ -11,6 +11,7 @@ #include "xnnpack/memory-planner.h" #include "xnnpack/node-type.h" #include "xnnpack/subgraph.h" +#include "runtime-flags.h" #include "runtime-tester.h" #include "subgraph-tester.h" @@ -212,7 +213,7 @@ TEST(MemoryPlanner, LeakyReluInPlaceAfterConv) { input_id, filter_id, bias_id, conv_out) .AddLeakyRelu(1.0f, conv_out, leaky_relu_out) .AddClamp(0.0f, 1.0f, leaky_relu_out, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -256,7 +257,7 @@ TEST(MemoryPlanner, LeakyReluWithTwoConsumersCannotBeInPlace) { .AddLeakyRelu(1.0f, conv_out, leaky_relu_out) .AddClamp(0.0f, 1.0f, leaky_relu_out, output_id) .AddClamp(1.0f, 2.0f, leaky_relu_out, output_id2); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -302,7 +303,7 @@ TEST(MemoryPlanner, HardSwishAndLeakyReluInPlaceAfterConv) { .AddLeakyRelu(1.0f, conv_out, leaky_relu_out) .AddHardSwish(leaky_relu_out, hard_swish_out) .AddClamp(0.0f, 1.0f, hard_swish_out, output_id); - tester.CreateRuntime(); + 
tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -327,7 +328,7 @@ TEST(MemoryPlanner, ExternalInputsCannotBeInPlace) { .AddOutputTensorF32({1, 3, 3, 3}, output_id) .AddLeakyRelu(1.0f, input_id, leaky_relu_out) .AddClamp(0.0f, 1.0f, leaky_relu_out, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -353,7 +354,7 @@ TEST(MemoryPlanner, PersistentValuesCannotReuseInternalValues) { .AddClamp(0.0f, 1.0f, input_id, clamp_out_id) .AddLeakyRelu(1.0f, clamp_out_id, leaky_relu_out_id) .AddClamp(0.0f, 1.0f, leaky_relu_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -379,7 +380,7 @@ TEST(MemoryPlanner, CannotReuseStaticValues) { .AddOutputTensorF32({1, 3, 3, 3}, output_id) .AddClamp(0.0f, 1.0f, static_id, clamp_out_id) .AddLeakyRelu(1.0f, clamp_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -420,7 +421,7 @@ TEST(MemoryPlanner, Add2WithLHSConstantInPlace) { input_id, filter_id, bias_id, conv_out) .AddAddition(add_constant_input_id, conv_out, add_out_id) .AddLeakyRelu(1.0f, add_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -462,7 +463,7 @@ TEST(MemoryPlanner, Add2WithLHSConstant) { input_id, filter_id, bias_id, conv_out) .AddAddition(add_constant_input_id, conv_out, add_out_id) .AddLeakyRelu(1.0f, add_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -504,7 +505,7 @@ TEST(MemoryPlanner, Add2WithRHSConstantInPlace) { input_id, filter_id, bias_id, conv_out) .AddAddition(conv_out, add_constant_input_id, add_out_id) .AddLeakyRelu(1.0f, add_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -545,7 +546,7 @@ TEST(MemoryPlanner, Mul2WithLHSConstant) { input_id, filter_id, bias_id, conv_out) .AddMultiply(mul_constant_input_id, conv_out, mul_out_id) .AddLeakyRelu(1.0f, mul_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -586,7 +587,7 @@ TEST(MemoryPlanner, Mul2WithRHSConstant) { input_id, filter_id, bias_id, conv_out) .AddMultiply(conv_out, mul_constant_input_id, mul_out_id) .AddLeakyRelu(1.0f, mul_out_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -639,7 +640,7 @@ TEST(MemoryPlanner, Add2WithImplicitBroadcast) { input2_id, filter_id, bias_id, conv_out) .AddAddition(hard_swish_out, conv_out, add_out) .AddLeakyRelu(1.0f, add_out, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -707,7 +708,7 @@ TEST(MemoryPlanner, Add2WithInputMultipleConsumers) { /*output_id=*/max_pooling_2d_out) .AddAddition(conv_out, max_pooling_2d_out, add_out) .AddLeakyRelu(1.0f, add_out, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); 
tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); @@ -749,7 +750,7 @@ TEST(MemoryPlanner, FullyConnectedDynamicFilterDynamicBias) { .AddConstantPad({1}, {0}, 0.0f, input3_id, bias_id) .AddFullyConnected(input1_id, filter_id, bias_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); xnn_operator_data* fc_opdata = &runtime->opdata[2]; @@ -787,7 +788,7 @@ TEST(MemoryPlanner, FullyConnectedDynamicFilterStaticBias) { .AddConstantPad({0, 0, 0, 1}, {0, 0, 0, 0}, 0.0f, input2_id, filter_id) .AddFullyConnected(input1_id, filter_id, bias_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); xnn_operator_data* fc_opdata = &runtime->opdata[1]; @@ -823,7 +824,7 @@ TEST(MemoryPlanner, FullyConnectedDynamicFilterNoBias) { .AddConstantPad({0, 0, 0, 1}, {0, 0, 0, 0}, 0.0f, input2_id, filter_id) .AddFullyConnected(input1_id, filter_id, bias_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); xnn_operator_data* fc_opdata = &runtime->opdata[1]; @@ -860,7 +861,7 @@ TEST(MemoryPlanner, FullyConnectedStaticFilterDynamicBias) { .AddConstantPad({1}, {0}, 0.0f, input3_id, bias_id) .AddFullyConnected(input1_id, filter_id, bias_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); xnn_operator_data* fc_opdata = &runtime->opdata[1]; @@ -891,7 +892,7 @@ TEST(MemoryPlanner, FullyConnectedExternalFilterExternalBias) { .AddOutputTensorF32({2, 3, 3, 2}, output_id) .AddFullyConnected(input_id, filter_id, bias_id, output_id); - tester.CreateRuntime(); + tester.CreateRuntime(xnn_test_runtime_flags()); tester.SetupRuntime(); xnn_runtime_t runtime = tester.Runtime(); xnn_operator_data* fc_opdata = &runtime->opdata[0]; diff --git a/test/microkernel-utils.cc b/test/microkernel-utils.cc index 58680c1205d3..79ac5971bbee 100644 --- a/test/microkernel-utils.cc +++ b/test/microkernel-utils.cc @@ -51,12 +51,10 @@ TEST(GEMM_BEST_NC, min_tiles_per_thread) { << ", num_threads=" << num_threads; } - // Verify that the next-largest `nc` would indeed be too large. - if (nc < nr) { - const size_t num_tiles_n = divide_round_up(n, nc + nr); - const size_t num_tiles = num_groups * num_tiles_m * num_tiles_n; - EXPECT_GT(min_num_tiles, num_tiles) - << "Computed `nc` is too conservative, num_groups=" << num_groups + // Verify that the next-smallest `nc` would increase the number of tiles. 
+ if (nr < nc && nc < n) { + EXPECT_NE(divide_round_up(n, nc), divide_round_up(n, nc - nr)) + << "Failed to get minimal `nc` for num_groups=" << num_groups << ", m=" << m << ", n=" << n << ", " << "mr=" << mr << " , " << "nr=" << nr << " , " << "nc=" << nc << ", num_threads=" << num_threads; diff --git a/test/packw-microkernel-tester.h b/test/packw-microkernel-tester.h index 1e30db3214e6..348e7049b16b 100644 --- a/test/packw-microkernel-tester.h +++ b/test/packw-microkernel-tester.h @@ -107,7 +107,7 @@ class PackWMicrokernelTester { } void Test(xnn_qs8_packw_gemm_goi_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(int8_t) + n() * k()); + xnnpack::Buffer weights(n() * k()); xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); @@ -158,7 +158,7 @@ class PackWMicrokernelTester { } void Test(xnn_qs8_packw_gemm_gio_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(int8_t) + n() * k()); + xnnpack::Buffer weights(n() * k()); xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); @@ -209,7 +209,7 @@ class PackWMicrokernelTester { } void Test(xnn_x8_packw_gemm_goi_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(int8_t) + n() * k()); + xnnpack::Buffer weights(n() * k()); xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); @@ -245,7 +245,7 @@ class PackWMicrokernelTester { } void Test(xnn_x8_packw_gemm_gio_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(int8_t) + n() * k()); + xnnpack::Buffer weights(n() * k()); xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); @@ -284,7 +284,9 @@ class PackWMicrokernelTester { xnnpack::ReplicableRandomDevice rng; auto i32rng = std::bind(std::uniform_int_distribution(-10000, 10000), std::ref(rng)); - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(int8_t) + n() * k()); + const size_t k2 = round_up_po2(k(), 2); // Round up to byte aligned rows + + xnnpack::Buffer weights(n() * k2 / 2); xnnpack::Buffer bias(n()); xnnpack::Buffer packed_w( packed_n() * packed_k() + packed_n() * sizeof(uint32_t)); @@ -300,7 +302,7 @@ class PackWMicrokernelTester { const xnn_qs8_qc4w_packing_params packing_params = { 0 }; // Compute reference results. - xnn_pack_qs8_qc4w_gemm_goi_w(/*g=*/1, n(), k(), nr(), kr(), sr(), + xnn_pack_qs8_qc4w_gemm_goi_w(/*g=*/1, n(), k2, nr(), kr(), sr(), weights.data(), bias_data, /*scale=*/nullptr, @@ -308,7 +310,7 @@ class PackWMicrokernelTester { /*extra_bytes=*/0, &packing_params); // Call optimized micro-kernel. - packw(/*g=*/1, n(), k(), nr(), kr(), sr(), + packw(/*g=*/1, n(), k2, nr(), kr(), sr(), weights.data(), bias_data, /*scale=*/nullptr, packed_w.data(), /*extra_bytes=*/0, &packing_params); // Verify bias results. @@ -319,13 +321,13 @@ class PackWMicrokernelTester { } // Verify weights results. 
- // NOTE remainder KC is different so k() is used instead of packed_k() for loop - for (size_t ki = 0; ki < k(); ki++) { + // NOTE remainder KC is different so k2 is used instead of packed_k() for loop + for (size_t ki = 0; ki < k2; ki++) { for (size_t ni = 0; ni < (n()); ni++) { const size_t i = packed_n() * sizeof(int32_t) + ki * packed_n() + ni; if (packed_w_ref[i] != INT8_C(0x7B)) { // Allow pad to differ EXPECT_EQ((int32_t) packed_w[i], (int32_t) packed_w_ref[i]) - << "kr " << kr() << " of kc " << k() << " packed_k " << packed_k() << "\n" + << "kr " << kr() << " of kc " << k2 << " packed_k " << packed_k() << "\n" << "nr " << nr() << " of nc " << n() << " packed_n " << packed_n() << "\n" << "at n " << i << " of " << (int32_t) (packed_n() * packed_k() + packed_n() * sizeof(int32_t)); } @@ -334,7 +336,7 @@ class PackWMicrokernelTester { } void Test(xnn_x16_packw_gemm_goi_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(xnn_float16) + g() * n() * k()); + xnnpack::Buffer weights(g() * n() * k()); xnnpack::Buffer padded_weights(g() * n() * packed_k()); xnnpack::Buffer bias(g() * n()); xnnpack::Buffer packed_w( @@ -386,7 +388,7 @@ class PackWMicrokernelTester { } void Test(xnn_x32_packw_gemm_goi_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(uint32_t) + g() * n() * k()); + xnnpack::Buffer weights(g() * n() * k()); xnnpack::Buffer padded_weights(g() * n() * packed_k()); xnnpack::Buffer bias(g() * n()); xnnpack::Buffer packed_w( @@ -436,7 +438,7 @@ class PackWMicrokernelTester { } void Test(xnn_x32_packw_gemm_gio_ukernel_fn packw) const { - xnnpack::Buffer weights(XNN_EXTRA_BYTES / sizeof(uint32_t) + g() * n() * k()); + xnnpack::Buffer weights(g() * n() * k()); xnnpack::Buffer padded_weights(g() * n() * packed_k()); xnnpack::Buffer bias(g() * n()); xnnpack::Buffer packed_w( diff --git a/test/qd8-f16-qb4w-gemm-minmax.cc b/test/qd8-f16-qb4w-gemm-minmax.cc index 3eea94fdaf96..fbaaf883609a 100644 --- a/test/qd8-f16-qb4w-gemm-minmax.cc +++ b/test/qd8-f16-qb4w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -55,14 +56,6 @@ std::vector CreateTests1( .b_zero_point(8) .bl(32) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - .bl(32) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -120,6 +113,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x2__scalar, xnn_init_f16_qb4w_minmax_scalar_params, @@ -137,6 +131,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x4__scalar, 
xnn_init_f16_qb4w_minmax_scalar_params, @@ -154,6 +149,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8__scalar, xnn_init_f16_qb4w_minmax_scalar_params, @@ -171,6 +167,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x2__scalar, xnn_init_f16_qb4w_minmax_scalar_params, @@ -188,6 +185,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x4__scalar, xnn_init_f16_qb4w_minmax_scalar_params, @@ -205,6 +203,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8__scalar, xnn_init_f16_qb4w_minmax_scalar_params, @@ -222,6 +221,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar, xnn_init_f16_qb4w_minmax_scalar_params, @@ -240,6 +240,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -259,6 +260,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -278,6 +280,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -297,6 +300,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x16c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -316,6 +320,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -335,6 +340,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x16c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -354,6 +360,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x8c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -373,6 +380,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x16c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -392,6 +400,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x8c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -411,6 +420,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x16c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -430,6 +440,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x8c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -449,6 +460,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x16c4__neondotfp16arith, xnn_init_f16_qb4w_minmax_scalar_params, @@ -471,6 +483,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8c8__avx2, xnn_init_f16_qb4w_minmax_scalar_params, @@ -490,6 +503,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8c8__avx2, xnn_init_f16_qb4w_minmax_scalar_params, @@ -509,6 +523,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c8__avx2, xnn_init_f16_qb4w_minmax_scalar_params, @@ -528,6 +543,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x8c8__avx2, xnn_init_f16_qb4w_minmax_scalar_params, @@ -550,6 +566,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16__neonfp16arith_mlal_lane, xnn_init_f16_qb4w_minmax_scalar_params, @@ -569,6 +586,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x16__neonfp16arith_mlal_lane, xnn_init_f16_qb4w_minmax_scalar_params, @@ -588,6 +606,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x16__neonfp16arith_mlal_lane, xnn_init_f16_qb4w_minmax_scalar_params, @@ -607,6 +626,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x16__neonfp16arith_mlal_lane, xnn_init_f16_qb4w_minmax_scalar_params, @@ -626,6 +646,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x16__neonfp16arith_mlal_lane, xnn_init_f16_qb4w_minmax_scalar_params, @@ -645,6 +666,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -664,6 +686,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -683,6 +706,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -702,6 +726,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -721,6 +746,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -743,6 +769,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -762,6 +789,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -781,6 +809,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_1x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -800,6 +829,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -819,6 +849,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -838,6 +869,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_2x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -857,6 +889,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -876,6 +909,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -895,6 +929,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_3x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -914,6 +949,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -933,6 +969,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -952,6 +989,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_4x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -971,6 +1009,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -990,6 +1029,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1009,6 +1049,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_5x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1028,6 +1069,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1047,6 +1089,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1066,6 +1109,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_6x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1085,6 +1129,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_7x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1104,6 +1149,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_7x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1123,6 +1169,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_7x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1142,6 +1189,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1161,6 +1209,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, @@ -1180,6 +1229,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qb4w_gemm_minmax_ukernel_8x32c8__neoni8mm, xnn_init_f16_qb4w_minmax_scalar_params, diff --git a/test/qd8-f16-qc4w-gemm-minmax-2.cc b/test/qd8-f16-qc4w-gemm-minmax-2.cc index 37a32e67046c..0e05f9320e05 100644 --- a/test/qd8-f16-qc4w-gemm-minmax-2.cc +++ b/test/qd8-f16-qc4w-gemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr 
+ 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -346,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -365,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -384,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -403,6 +382,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -425,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -444,6 +425,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -463,6 +445,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -485,6 +468,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -504,6 +488,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -523,6 +508,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -542,6 +528,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -561,6 +548,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -583,6 +571,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -602,6 +591,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -621,6 +611,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x16c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -640,6 +631,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -659,6 +651,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x16c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -681,6 +674,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16__neonfp16arith_mlal_lane, xnn_init_f16_qc4w_minmax_scalar_params, @@ -700,6 +694,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x16__neonfp16arith_mlal_lane, xnn_init_f16_qc4w_minmax_scalar_params, @@ -719,6 +714,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -738,6 +734,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x16__neonfp16arith_mlal_lane, xnn_init_f16_qc4w_minmax_scalar_params, @@ -757,6 +754,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -776,6 +774,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -798,6 +797,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -817,6 +817,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -836,6 +837,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -855,6 +857,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -877,6 +880,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -896,6 +900,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -915,6 +920,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -934,6 +940,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -956,6 +963,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, 
/*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -975,6 +983,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -994,6 +1003,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1013,6 +1023,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1035,6 +1046,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1054,6 +1066,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1076,6 +1089,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1095,6 +1109,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1114,6 +1129,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, diff --git a/test/qd8-f16-qc4w-gemm-minmax-3.cc b/test/qd8-f16-qc4w-gemm-minmax-3.cc index ff6262e66bce..6ae023f7d330 100644 --- a/test/qd8-f16-qc4w-gemm-minmax-3.cc +++ b/test/qd8-f16-qc4w-gemm-minmax-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - 
.m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -346,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -365,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -384,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -406,6 +385,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -425,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -444,6 +425,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -466,6 +448,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -485,6 +468,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, 
@@ -504,6 +488,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -523,6 +508,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -542,6 +528,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -561,6 +548,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -580,6 +568,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -602,6 +591,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x16c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -621,6 +611,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x16c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -640,6 +631,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -659,6 +651,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x16c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -681,6 +674,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -700,6 +694,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -719,6 +714,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -738,6 +734,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -757,6 +754,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -779,6 +777,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -798,6 +797,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -817,6 +817,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -836,6 +837,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -858,6 +860,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -877,6 +880,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -896,6 +900,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -915,6 +920,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -937,6 +943,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -956,6 +963,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -975,6 +983,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -997,6 +1006,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, diff --git a/test/qd8-f16-qc4w-gemm-minmax-4.cc b/test/qd8-f16-qc4w-gemm-minmax-4.cc index fedec22efdbd..a9d4a0676f45 100644 --- a/test/qd8-f16-qc4w-gemm-minmax-4.cc +++ b/test/qd8-f16-qc4w-gemm-minmax-4.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -346,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -365,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -384,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -406,6 +385,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -425,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -444,6 +425,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -463,6 +445,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -485,6 +468,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -504,6 +488,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -523,6 +508,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -542,6 +528,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -561,6 +548,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -580,6 +568,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -602,6 +591,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -621,6 +611,7 @@ std::vector 
CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -643,6 +634,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -662,6 +654,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x16__neonfp16arith_mlal_lane_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -684,6 +677,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -706,6 +700,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -725,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -744,6 +740,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -763,6 +760,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -782,6 +780,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -804,6 +803,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -823,6 +823,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -842,6 +843,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -861,6 +863,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -880,6 +883,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -902,6 +906,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -924,6 +929,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, @@ -943,6 +949,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, diff --git a/test/qd8-f16-qc4w-gemm-minmax.cc b/test/qd8-f16-qc4w-gemm-minmax.cc index 6f1f33120358..096b6834dda7 100644 --- a/test/qd8-f16-qc4w-gemm-minmax.cc +++ b/test/qd8-f16-qc4w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr 
* 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -346,6 +322,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -365,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -384,6 +362,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -403,6 +382,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -422,6 +402,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -441,6 +422,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -463,6 +445,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -482,6 +465,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -501,6 +485,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd, xnn_init_f16_qc4w_minmax_scalar_params, @@ -520,6 +505,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -539,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -558,6 +545,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -580,6 +568,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -599,6 +588,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -618,6 +608,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -637,6 +628,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -656,6 +648,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x16c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -675,6 +668,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x32c8__neoni8mm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -697,6 +691,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c4__neondotfp16arith, xnn_init_f16_qc4w_minmax_scalar_params, @@ -719,6 +714,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x16__neonfp16arith_mlal_lane, xnn_init_f16_qc4w_minmax_scalar_params, @@ -738,6 +734,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x16__neonfp16arith_mlal_lane, xnn_init_f16_qc4w_minmax_scalar_params, @@ -760,6 +757,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -779,6 +777,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -798,6 +797,7 @@ 
std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -817,6 +817,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -836,6 +837,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -855,6 +857,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -877,6 +880,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -896,6 +900,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -915,6 +920,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -937,6 +943,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni, xnn_init_f16_qc4w_minmax_scalar_params, @@ -956,6 +963,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -975,6 +983,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm, xnn_init_f16_qc4w_minmax_scalar_params, @@ -997,6 +1006,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1016,6 +1026,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2, 
xnn_init_f16_qc4w_minmax_scalar_params, @@ -1038,6 +1049,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, @@ -1057,6 +1069,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx, xnn_init_f16_qc4w_minmax_scalar_params, diff --git a/test/qd8-f16-qc4w-gemm-minmax.yaml b/test/qd8-f16-qc4w-gemm-minmax.yaml index 4212227b6473..74095e68f27b 100644 --- a/test/qd8-f16-qc4w-gemm-minmax.yaml +++ b/test/qd8-f16-qc4w-gemm-minmax.yaml @@ -8,134 +8,166 @@ init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVX256SDK MADD - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True # ARM NEONI8MM @@ -333,201 +365,249 @@ init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVX256 VNNI GFNI - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni init: 
xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVXVNNI - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w 
k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm init: xnn_init_f16_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # x86 AVX2 - name: xnn_qd8_f16_qc4w_gemm_minmax_ukernel_1x8c8__avx2 diff --git a/test/qd8-f16-qc8w-gemm-minmax-2.cc b/test/qd8-f16-qc8w-gemm-minmax-2.cc index 4f7c5cd24d87..613d51f43ccc 100644 --- a/test/qd8-f16-qc8w-gemm-minmax-2.cc +++ b/test/qd8-f16-qc8w-gemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm, xnn_init_f16_minmax_scalar_params, @@ -373,6 +355,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -392,6 +375,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -411,6 +395,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -430,6 +415,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -449,6 +435,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -468,6 +455,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -490,6 +478,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -509,6 +498,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -528,6 +518,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -547,6 +538,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -569,6 +561,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -588,6 +581,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -607,6 +601,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -626,6 +621,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -645,6 +641,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -667,6 +664,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -686,6 +684,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -705,6 +704,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -727,6 +727,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, @@ -746,6 +747,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-gemm-minmax-3.cc b/test/qd8-f16-qc8w-gemm-minmax-3.cc index feb827f587e9..9b7948d28294 100644 --- a/test/qd8-f16-qc8w-gemm-minmax-3.cc +++ b/test/qd8-f16-qc8w-gemm-minmax-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { 
std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -430,6 +415,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -449,6 +435,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -468,6 +455,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -487,6 +475,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -506,6 +495,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -528,6 +518,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -550,6 +541,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -569,6 +561,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -588,6 +581,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -610,6 +604,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -632,6 +627,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-gemm-minmax-4.cc b/test/qd8-f16-qc8w-gemm-minmax-4.cc index a3d792c6581e..25a48bd4685d 100644 --- a/test/qd8-f16-qc8w-gemm-minmax-4.cc +++ b/test/qd8-f16-qc8w-gemm-minmax-4.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 
+44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -408,6 +392,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -427,6 +412,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -449,6 +435,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -471,6 +458,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c2s4__neonfp16arith, xnn_init_f16_minmax_scalar_params, @@ -490,6 +478,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c2s4__neonfp16arith, xnn_init_f16_minmax_scalar_params, @@ -512,6 +501,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondotfp16arith_cortex_a55, xnn_init_f16_minmax_scalar_params, @@ -534,6 +524,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -553,6 +544,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -572,6 +564,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -594,6 +587,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -613,6 +607,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -632,6 +627,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -651,6 +647,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -670,6 +667,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -689,6 +687,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -711,6 +710,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -730,6 +730,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -752,6 +753,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-gemm-minmax.cc b/test/qd8-f16-qc8w-gemm-minmax.cc index 4cc1815c005e..864b3e70b77d 100644 --- a/test/qd8-f16-qc8w-gemm-minmax.cc +++ b/test/qd8-f16-qc8w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx, xnn_init_f16_minmax_scalar_params, @@ -335,6 +315,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -354,6 +335,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -373,6 +355,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -392,6 +375,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -411,6 +395,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x32c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -433,6 +418,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -455,6 +441,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondotfp16arith_cortex_a55, xnn_init_f16_minmax_scalar_params, @@ -474,6 +461,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondotfp16arith_ld128, xnn_init_f16_minmax_scalar_params, @@ -496,6 +484,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -518,6 +507,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -537,6 +527,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -556,6 +547,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -575,6 +567,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -594,6 +587,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -613,6 +607,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -632,6 +627,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -654,6 +650,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -673,6 +670,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -692,6 +690,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -711,6 +710,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -733,6 +733,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avx2, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-gemm-minmax.yaml b/test/qd8-f16-qc8w-gemm-minmax.yaml index dce4ffbff31a..0099ed37f20c 100644 --- a/test/qd8-f16-qc8w-gemm-minmax.yaml +++ b/test/qd8-f16-qc8w-gemm-minmax.yaml @@ -198,134 +198,166 @@ init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVXVNNI - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + 
unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True # x86 AVX2 - name: xnn_qd8_f16_qc8w_gemm_minmax_ukernel_1x8c8__avx2 diff --git a/test/qd8-f16-qc8w-igemm-minmax-2.cc b/test/qd8-f16-qc8w-igemm-minmax-2.cc index f1dc2dd4ce15..6e6741c64034 100644 --- a/test/qd8-f16-qc8w-igemm-minmax-2.cc +++ b/test/qd8-f16-qc8w-igemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -373,6 +355,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -392,6 +375,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -411,6 +395,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -433,6 +418,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x16c4__asm_aarch64_neondotfp16arith_cortex_a55, xnn_init_f16_minmax_scalar_params, @@ -455,6 +441,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -474,6 +461,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x32c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -496,6 +484,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c2s4__neonfp16arith, xnn_init_f16_minmax_scalar_params, @@ -515,6 +504,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c2s4__neonfp16arith, xnn_init_f16_minmax_scalar_params, @@ -537,6 +527,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -556,6 
+547,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -578,6 +570,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -600,6 +593,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -619,6 +613,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -638,6 +633,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -660,6 +656,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-igemm-minmax-3.cc b/test/qd8-f16-qc8w-igemm-minmax-3.cc index fdd93b202aca..a5726bea01cd 100644 --- a/test/qd8-f16-qc8w-igemm-minmax-3.cc +++ b/test/qd8-f16-qc8w-igemm-minmax-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - 
tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm, xnn_init_f16_minmax_scalar_params, @@ -373,6 +355,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -392,6 +375,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -414,6 +398,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x32c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -433,6 +418,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x32c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -455,6 +441,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x16c4__asm_aarch64_neondotfp16arith_ld128, xnn_init_f16_minmax_scalar_params, @@ -477,6 +464,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x32c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -496,6 +484,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -515,6 +504,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ 
-534,6 +524,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -556,6 +547,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c4__asm_aarch32_neondotfp16arith_cortex_a55, xnn_init_f16_minmax_scalar_params, @@ -578,6 +570,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -600,6 +593,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -619,6 +613,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -638,6 +633,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -657,6 +653,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -676,6 +673,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -695,6 +693,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -714,6 +713,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -733,6 +733,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -752,6 +753,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm, 
xnn_init_f16_minmax_scalar_params, @@ -771,6 +773,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -793,6 +796,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -812,6 +816,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -831,6 +836,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -853,6 +859,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, @@ -872,6 +879,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-igemm-minmax-4.cc b/test/qd8-f16-qc8w-igemm-minmax-4.cc index c847e8c8d5ba..ac2627ead9e1 100644 --- a/test/qd8-f16-qc8w-igemm-minmax-4.cc +++ b/test/qd8-f16-qc8w-igemm-minmax-4.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - 
gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx, xnn_init_f16_minmax_scalar_params, @@ -335,6 +315,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -354,6 +335,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -373,6 +355,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -395,6 +378,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -414,6 +398,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x32c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -436,6 +421,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -455,6 +441,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -477,6 +464,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -496,6 +484,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -515,6 +504,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, 
@@ -534,6 +524,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -553,6 +544,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -572,6 +564,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-igemm-minmax.cc b/test/qd8-f16-qc8w-igemm-minmax.cc index ddee7beb9c07..d7ef296e8bf5 100644 --- a/test/qd8-f16-qc8w-igemm-minmax.cc +++ b/test/qd8-f16-qc8w-igemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f16_minmax_scalar_params, @@ -392,6 +375,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -411,6 +395,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x16c4__neondotfp16arith, xnn_init_f16_minmax_scalar_params, @@ -433,6 +418,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx2, xnn_init_f16_minmax_scalar_params, @@ -455,6 +441,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -474,6 +461,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -493,6 +481,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni, xnn_init_f16_minmax_scalar_params, @@ -515,6 +504,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -534,6 +524,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni, xnn_init_f16_minmax_scalar_params, @@ -553,6 +544,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -572,6 +564,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm, xnn_init_f16_minmax_scalar_params, @@ -594,6 +587,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, 
/*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256skx, xnn_init_f16_minmax_scalar_params, diff --git a/test/qd8-f16-qc8w-igemm-minmax.yaml b/test/qd8-f16-qc8w-igemm-minmax.yaml index 1601e3b4adb3..d6868c322a3a 100644 --- a/test/qd8-f16-qc8w-igemm-minmax.yaml +++ b/test/qd8-f16-qc8w-igemm-minmax.yaml @@ -180,134 +180,166 @@ init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True # x86 AVX VNNI - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm init: xnn_init_f16_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True # x86 AVX256SKX - name: xnn_qd8_f16_qc8w_igemm_minmax_ukernel_1x8c8__avx256skx diff --git a/test/qd8-f32-qb4w-gemm-minmax.cc b/test/qd8-f32-qb4w-gemm-minmax.cc index e3135adff8dc..dff3c4f4aa5f 100644 --- a/test/qd8-f32-qb4w-gemm-minmax.cc +++ b/test/qd8-f32-qb4w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -55,14 +56,6 @@ std::vector CreateTests1( .b_zero_point(8) .bl(32) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - .bl(32) - , test_func, isa_check)); if (!is_igemm) { 
gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -120,6 +113,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x2__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -137,6 +131,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -154,6 +149,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -171,6 +167,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x2__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -188,6 +185,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -205,6 +203,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x8__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -222,6 +221,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4__scalar, xnn_init_f32_qb4w_minmax_scalar_params, @@ -240,6 +240,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__avx_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -259,6 +260,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__avx_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -278,6 +280,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__avx_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -297,6 +300,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__avx_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -316,6 +320,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__avx_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -335,6 +340,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__avx_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -354,6 +360,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__avx_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -373,6 +380,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__avx_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -392,6 +400,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse2_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -411,6 +420,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse2_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -430,6 +440,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse2_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -449,6 +460,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse2_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -468,6 +480,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse2_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -487,6 +500,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse2_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -506,6 +520,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse2_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -525,6 +540,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse2_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -544,6 +560,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -563,6 +580,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -582,6 +600,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -601,6 +620,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld128, xnn_init_f32_qb4w_minmax_scalar_params, @@ -620,6 +640,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x4c8__sse41_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -639,6 +660,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x4c8__sse41_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -658,6 +680,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x4c8__sse41_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -677,6 +700,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x4c8__sse41_ld64, xnn_init_f32_qb4w_minmax_scalar_params, @@ -699,6 +723,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -718,6 +743,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -737,6 +763,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x8c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -756,6 +783,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x16c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -775,6 +803,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, 
+ /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x8c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -794,6 +823,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x16c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -813,6 +843,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x8c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -832,6 +863,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x16c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -851,6 +883,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x8c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -870,6 +903,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -889,6 +923,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x8c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -908,6 +943,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x16c4__neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -930,6 +966,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8c8__avx2, xnn_init_f32_qb4w_minmax_scalar_params, @@ -949,6 +986,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x8c8__avx2, xnn_init_f32_qb4w_minmax_scalar_params, @@ -968,6 +1006,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x8c8__avx2, xnn_init_f32_qb4w_minmax_scalar_params, @@ -987,6 +1026,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x8c8__avx2, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1009,6 +1049,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1028,6 +1069,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16__neon_mlal_lane_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1047,6 +1089,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x16__neon_mlal_lane, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1066,6 +1109,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x16__neon_mlal_lane_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1085,6 +1129,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x16__neon_mlal_lane, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1104,6 +1149,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x16__neon_mlal_lane_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1123,6 +1169,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x16__neon_mlal_lane, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1142,6 +1189,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x16__neon_mlal_lane_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1161,6 +1209,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x16__neon_mlal_lane, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1180,6 +1229,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x16__neon_mlal_lane_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1202,6 +1252,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1221,6 +1272,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1240,6 
+1292,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1259,6 +1312,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1278,6 +1332,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1297,6 +1352,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_2x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1316,6 +1372,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1335,6 +1392,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1354,6 +1412,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_3x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1373,6 +1432,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1392,6 +1452,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1411,6 +1472,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_4x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1430,6 +1492,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1449,6 +1512,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__neoni8mm, 
xnn_init_f32_qb4w_minmax_scalar_params, @@ -1468,6 +1532,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1487,6 +1552,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1506,6 +1572,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1525,6 +1592,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_6x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1544,6 +1612,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1563,6 +1632,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1582,6 +1652,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1601,6 +1672,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1620,6 +1692,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1639,6 +1712,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x32c8__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1661,6 +1735,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1680,6 +1755,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1699,6 +1775,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1718,6 +1795,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1737,6 +1815,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1756,6 +1835,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1775,6 +1855,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1794,6 +1875,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1813,6 +1895,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1832,6 +1915,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1851,6 +1935,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1870,6 +1955,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1889,6 +1975,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1908,6 +1995,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/10, /*nr=*/16, 
/*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1927,6 +2015,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1946,6 +2035,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1968,6 +2058,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -1987,6 +2078,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2006,6 +2098,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2025,6 +2118,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2044,6 +2138,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2063,6 +2158,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2082,6 +2178,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2101,6 +2198,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2120,6 +2218,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
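// Each of these lambdas simply forwards the shared tester to GemmMicrokernelTester::Test
// together with the microkernel, the minmax-params init function, and the weight-packing
// function for that tile size; the /*unsigned_inputs=*/ flag above is the only new input.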
tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2139,6 +2238,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2158,6 +2258,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2177,6 +2278,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2196,6 +2298,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2215,6 +2318,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2234,6 +2338,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -2253,6 +2358,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/32, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm, xnn_init_f32_qb4w_minmax_scalar_params, diff --git a/test/qd8-f32-qb4w-gemm-minmax.yaml b/test/qd8-f32-qb4w-gemm-minmax.yaml index 3c1c4e5f9df1..49a106a0b496 100644 --- a/test/qd8-f32-qb4w-gemm-minmax.yaml +++ b/test/qd8-f32-qb4w-gemm-minmax.yaml @@ -375,131 +375,163 @@ init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - 
name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True # AVX512VNNIGFNI - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: 
xnn_qd8_f32_qb4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qb4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qb4w_minmax_scalar_params pack: xnn_pack_qs8_qb4w_gemm_goi_w k-block: 32 + unsigned-inputs: True diff --git a/test/qd8-f32-qc4w-gemm-minmax-2.cc b/test/qd8-f32-qc4w-gemm-minmax-2.cc index c3d446a84802..66bc1f6f935a 100644 --- a/test/qd8-f32-qc4w-gemm-minmax-2.cc +++ b/test/qd8-f32-qc4w-gemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -340,6 +316,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool 
is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -349,7 +326,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -360,13 +337,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -488,15 +458,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -525,15 +486,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -653,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -672,6 +625,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -691,6 +645,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -710,6 +665,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -732,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -751,6 +708,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd, 
xnn_init_f32_qc4w_minmax_scalar_params, @@ -770,6 +728,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -789,6 +748,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -808,6 +768,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -827,6 +788,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -849,6 +811,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -868,6 +831,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -887,6 +851,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -906,6 +871,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -928,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -947,6 +914,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -966,6 +934,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -985,6 +954,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1004,6 +974,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1023,6 +994,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1042,6 +1014,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1064,6 +1037,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x32c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1083,6 +1057,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x32c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1105,6 +1080,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1127,6 +1103,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1146,6 +1123,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1165,6 +1143,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1184,6 +1163,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1203,6 +1183,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1222,6 +1203,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/32, 
/*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1244,6 +1226,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16__neon_mlal_lane, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1263,6 +1246,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16__neon_mlal_lane_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1282,6 +1266,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16__neon_mlal_lane, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1301,6 +1286,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16__neon_mlal_lane_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1323,6 +1309,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1345,6 +1332,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1364,6 +1352,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1383,6 +1372,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1402,6 +1392,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1421,6 +1412,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1440,6 +1432,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni, 
xnn_init_f32_qc4w_minmax_scalar_params, @@ -1459,6 +1452,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1478,6 +1472,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1500,6 +1495,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1519,6 +1515,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1538,6 +1535,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1557,6 +1555,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1576,6 +1575,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1595,6 +1595,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1614,6 +1615,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1636,6 +1638,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1658,6 +1661,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1677,6 +1681,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, 
+ /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1696,6 +1701,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1718,6 +1724,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1737,6 +1744,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1756,6 +1764,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1775,6 +1784,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1797,6 +1807,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1816,6 +1827,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1838,6 +1850,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1857,6 +1870,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1879,6 +1893,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__avx_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1898,6 +1913,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__avx_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1917,6 +1933,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, 
/*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__avx_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1936,6 +1953,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__avx_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1955,6 +1973,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1974,6 +1993,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1993,6 +2013,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse2_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2012,6 +2033,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse2_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2031,6 +2053,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse2_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2050,6 +2073,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2069,6 +2093,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2091,6 +2116,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2110,6 +2136,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2132,6 +2159,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2150,6 +2178,7 @@ 
INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x2__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2167,6 +2196,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2185,6 +2215,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2201,6 +2232,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8__wasm, xnn_init_f32_qc4w_minmax_scalar_params, diff --git a/test/qd8-f32-qc4w-gemm-minmax-3.cc b/test/qd8-f32-qc4w-gemm-minmax-3.cc index 0db5075cfca2..e712d97e8689 100644 --- a/test/qd8-f32-qc4w-gemm-minmax-3.cc +++ b/test/qd8-f32-qc4w-gemm-minmax-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -340,6 +316,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -349,7 +326,7 @@ std::vector 
CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -360,13 +337,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -488,15 +458,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -525,15 +486,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -653,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -675,6 +628,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -694,6 +648,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -713,6 +668,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -735,6 +691,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -754,6 +711,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -773,6 +731,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -792,6 +751,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -811,6 +771,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -830,6 +791,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -849,6 +811,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -868,6 +831,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -887,6 +851,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -909,6 +874,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -928,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -947,6 +914,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -969,6 +937,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -988,6 +957,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1007,6 
+977,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1029,6 +1000,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1048,6 +1020,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1067,6 +1040,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1086,6 +1060,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1105,6 +1080,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1124,6 +1100,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1146,6 +1123,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16__neon_mlal_lane_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1165,6 +1143,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16__neon_mlal_lane, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1184,6 +1163,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16__neon_mlal_lane, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1206,6 +1186,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1225,6 +1206,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1247,6 +1229,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1266,6 +1249,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1285,6 +1269,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1304,6 +1289,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1323,6 +1309,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1342,6 +1329,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1361,6 +1349,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1380,6 +1369,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1399,6 +1389,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1421,6 +1412,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1440,6 +1432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1459,6 +1452,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, 
/*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1478,6 +1472,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1497,6 +1492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1516,6 +1512,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1535,6 +1532,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1554,6 +1552,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1573,6 +1572,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1592,6 +1592,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1614,6 +1615,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1633,6 +1635,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1652,6 +1655,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1671,6 +1675,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1693,6 +1698,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1712,6 +1718,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1731,6 +1738,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1750,6 +1758,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1769,6 +1778,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1788,6 +1798,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1810,6 +1821,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1829,6 +1841,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1848,6 +1861,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1867,6 +1881,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1889,6 +1904,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1911,6 +1927,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1930,6 +1947,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1952,6 +1970,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1971,6 +1990,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse2_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1993,6 +2013,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2012,6 +2033,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/6, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2034,6 +2056,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2052,6 +2075,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2069,6 +2093,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2087,6 +2112,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x2__wasm, xnn_init_f32_qc4w_minmax_scalar_params, diff --git a/test/qd8-f32-qc4w-gemm-minmax-4.cc b/test/qd8-f32-qc4w-gemm-minmax-4.cc index 441abfd323ef..3c54e40a16bb 100644 --- a/test/qd8-f32-qc4w-gemm-minmax-4.cc +++ b/test/qd8-f32-qc4w-gemm-minmax-4.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + 
.mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -340,6 +316,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -349,7 +326,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -360,13 +337,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -488,15 +458,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -525,15 +486,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -653,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -672,6 +625,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -694,6 +648,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -713,6 +668,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -732,6 +688,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -751,6 +708,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -770,6 +728,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -792,6 +751,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -811,6 +771,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -830,6 +791,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -849,6 +811,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -868,6 +831,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -890,6 +854,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -909,6 +874,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -928,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -947,6 +914,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -966,6 +934,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -985,6 +954,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1004,6 +974,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1023,6 +994,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1042,6 +1014,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1064,6 +1037,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x64c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1083,6 +1057,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1102,6 +1077,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ 
-1124,6 +1100,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1143,6 +1120,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1162,6 +1140,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1184,6 +1163,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1203,6 +1183,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1222,6 +1203,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1241,6 +1223,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1260,6 +1243,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1279,6 +1263,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1298,6 +1283,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1320,6 +1306,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16__neon_mlal_lane, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1342,6 +1329,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1361,6 +1349,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1380,6 +1369,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1399,6 +1389,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1418,6 +1409,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1440,6 +1432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1459,6 +1452,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1478,6 +1472,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1497,6 +1492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1516,6 +1512,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1535,6 +1532,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1554,6 +1552,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1576,6 +1575,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, 
/*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1595,6 +1595,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1614,6 +1615,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1633,6 +1635,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1652,6 +1655,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1671,6 +1675,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1690,6 +1695,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1709,6 +1715,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1728,6 +1735,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1747,6 +1755,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1766,6 +1775,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1788,6 +1798,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1807,6 +1818,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1829,6 +1841,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1848,6 +1861,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1870,6 +1884,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1889,6 +1904,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1908,6 +1924,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1930,6 +1947,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1949,6 +1967,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1968,6 +1987,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1990,6 +2010,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2012,6 +2033,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__avx_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2031,6 +2053,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__avx_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2050,6 +2073,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2069,6 +2093,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2091,6 +2116,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2110,6 +2136,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/5, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2132,6 +2159,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2150,6 +2178,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2167,6 +2196,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2185,6 +2215,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x2__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2201,6 +2232,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4__wasm, xnn_init_f32_qc4w_minmax_scalar_params, diff --git a/test/qd8-f32-qc4w-gemm-minmax.cc b/test/qd8-f32-qc4w-gemm-minmax.cc index 67b9d4054c98..20de62f9810f 100644 --- a/test/qd8-f32-qc4w-gemm-minmax.cc +++ b/test/qd8-f32-qc4w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); 
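For orientation, here is a minimal C++ sketch of what the CreateTests1 helper in test/qd8-f32-qc4w-gemm-minmax.cc looks like after the hunks above: the new unsigned_inputs flag is accepted right after is_igemm and forwarded once into the shared GemmMicrokernelTester, so every cloned GemmTestParams entry inherits it. The template arguments on std::vector/std::function and the single representative test entry are assumptions filled in for readability, not part of the diff itself.

    // Sketch only: how the new unsigned_inputs parameter is threaded into the tester.
    std::vector<GemmTestParams> CreateTests1(
        size_t k_block, size_t adj_k_block,
        size_t mr, size_t nr, size_t kr, size_t sr,
        bool is_igemm,
        bool unsigned_inputs,  // new parameter introduced by this change
        std::function<void(GemmMicrokernelTester& tester)> test_func,
        std::function<void()> isa_check = nullptr) {
      const GemmMicrokernelTester tester = GemmMicrokernelTester()
          .mr(mr).nr(nr).kr(kr).sr(sr)
          .unsigned_inputs(unsigned_inputs);  // forwarded to every clone below
      std::vector<GemmTestParams> gemm_tests;
      gemm_tests.push_back(GemmTestParams(
          "k_eq_" + std::to_string(k_block),
          tester.clone().m(mr).n(nr).k(k_block).b_zero_point(8),
          test_func, isa_check));
      // ...remaining m/n/k and stride variants as in the hunks above,
      // minus the removed strided_cn cases...
      return gemm_tests;
    }

Call sites then pass the flag positionally right after is_igemm, e.g. /*unsigned_inputs=*/true for the madd/vnni kernels and /*unsigned_inputs=*/false elsewhere, which is the repeated one-line addition seen throughout the hunks that follow.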
std::vector gemm_tests; gemm_tests.reserve(42); @@ -54,13 +55,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -182,15 +176,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -219,15 +204,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -340,6 +316,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -349,7 +326,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -360,13 +337,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -488,15 +458,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -525,15 +486,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -653,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd, 
xnn_init_f32_qc4w_minmax_scalar_params, @@ -672,6 +625,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -691,6 +645,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -710,6 +665,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -729,6 +685,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -748,6 +705,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -767,6 +725,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -786,6 +745,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -805,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -824,6 +785,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -846,6 +808,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -865,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -884,6 +848,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -903,6 +868,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -925,6 +891,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -944,6 +911,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -963,6 +931,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -982,6 +951,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1004,6 +974,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1023,6 +994,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1042,6 +1014,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1061,6 +1034,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1080,6 +1054,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1099,6 +1074,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1118,6 +1094,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, 
/*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1137,6 +1114,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1156,6 +1134,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1178,6 +1157,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x64c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1197,6 +1177,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x64c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1216,6 +1197,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1235,6 +1217,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1257,6 +1240,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1276,6 +1260,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__neondot, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1298,6 +1283,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1317,6 +1303,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1336,6 +1323,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c8__neoni8mm, 
xnn_init_f32_qc4w_minmax_scalar_params, @@ -1355,6 +1343,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x32c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1374,6 +1363,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1396,6 +1386,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x16__neon_mlal_lane_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1415,6 +1406,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16__neon_mlal_lane_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1437,6 +1429,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1456,6 +1449,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1475,6 +1469,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1494,6 +1489,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1513,6 +1509,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1532,6 +1529,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1551,6 +1549,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1570,6 +1569,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1589,6 +1589,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1608,6 +1609,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1630,6 +1632,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1649,6 +1652,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1668,6 +1672,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1687,6 +1692,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1706,6 +1712,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1725,6 +1732,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1747,6 +1755,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1766,6 +1775,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1785,6 +1795,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1804,6 +1815,7 @@ std::vector 
CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1823,6 +1835,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1842,6 +1855,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1861,6 +1875,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1880,6 +1895,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1899,6 +1915,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1921,6 +1938,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1940,6 +1958,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1959,6 +1978,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1978,6 +1998,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -1997,6 +2018,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2019,6 +2041,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2038,6 +2061,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2057,6 +2081,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2076,6 +2101,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2095,6 +2121,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2117,6 +2144,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2136,6 +2164,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2158,6 +2187,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2177,6 +2207,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2196,6 +2227,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2218,6 +2250,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__avx_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2237,6 +2270,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__avx_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2256,6 +2290,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_ld128, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2275,6 +2310,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2294,6 +2330,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2313,6 +2350,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2332,6 +2370,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2354,6 +2393,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2373,6 +2413,7 @@ std::vector CreateTests1( /*adj_k_block=*/2, /*mr=*/8, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x4v__rvv, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2395,6 +2436,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2413,6 +2455,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/1, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x1__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2430,6 +2473,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x2__scalar, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2448,6 +2492,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4__wasm, xnn_init_f32_qc4w_minmax_scalar_params, @@ -2464,6 +2509,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8__wasm, xnn_init_f32_qc4w_minmax_scalar_params, diff --git a/test/qd8-f32-qc4w-gemm-minmax.yaml b/test/qd8-f32-qc4w-gemm-minmax.yaml index c0a945942902..a3bdd7d7586d 100644 --- 
a/test/qd8-f32-qc4w-gemm-minmax.yaml +++ b/test/qd8-f32-qc4w-gemm-minmax.yaml @@ -8,345 +8,426 @@ init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__ssse3_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__ssse3_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__ssse3_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__ssse3_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__sse41_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x4c8__sse41_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x4c8__sse41_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__sse41_madd_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVX256SKX MADD - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4uw_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
-
+ unsigned-inputs: True
# AVX2 MADD
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx2_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
# AVX512SKX MADD
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512skx_madd_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4uw_gemm_goi_w
  k-block: 16
+ unsigned-inputs: True
# AVX512AMX
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x64c4__avx512amx
@@ -618,483 +699,599 @@
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm
  init: xnn_init_f32_qc4w_minmax_scalar_params
  pack: xnn_pack_qs8_qc4w_gemm_goi_w
  k-block: 8
+ unsigned-inputs: True
- name:
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm init: 
xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVX512 VNNI GFNI - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c4__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c4__avx512vnnigfni_prfm init: 
xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x16c8__avx512vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVX256 VNNI - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni init: 
xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVX256 VNNI GFNI - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: 
xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_9x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_10x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_12x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_14x8c8__avx256vnnigfni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVXVNNI - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f32_qc4w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm init: xnn_init_f32_qc4w_minmax_scalar_params pack: xnn_pack_qs8_qc4w_gemm_goi_w k-block: 16 + unsigned-inputs: True # x86 AVX2 - name: xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x8c8__avx2 diff --git a/test/qd8-f32-qc8w-gemm-minmax-2.cc b/test/qd8-f32-qc8w-gemm-minmax-2.cc index 10fa5fae7346..6b00ff50cf63 100644 --- a/test/qd8-f32-qc8w-gemm-minmax-2.cc +++ b/test/qd8-f32-qc8w-gemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -307,6 +286,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function 
test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -316,7 +296,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -326,12 +306,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -439,14 +413,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -472,14 +438,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -587,6 +545,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -606,6 +565,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x64c4__avx512amx_prfm, xnn_init_f32_minmax_scalar_params, @@ -625,6 +585,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -644,6 +605,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -663,6 +625,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx_prfm, xnn_init_f32_minmax_scalar_params, @@ -677,6 +640,232 @@ std::vector CreateTests1( #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, 
+ /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_5X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/5, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_9X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/9, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_10X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/10, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_5X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_9X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& 
tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_10X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_1X8C8__NEONI8MM, GemmTest, @@ -685,6 +874,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -704,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ 
-723,6 +914,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -742,6 +934,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -761,6 +954,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -783,6 +977,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c2s4__neon_mlal, xnn_init_f32_minmax_scalar_params, @@ -805,6 +1000,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -827,6 +1023,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -846,6 +1043,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -865,6 +1063,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -884,6 +1083,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -906,6 +1106,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -925,6 +1126,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -944,6 +1146,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ 
-966,6 +1169,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -985,6 +1189,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -1007,6 +1212,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1026,6 +1232,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1045,6 +1252,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1064,6 +1272,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1083,6 +1292,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1102,6 +1312,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1121,6 +1332,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1143,6 +1355,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -1165,6 +1378,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1184,6 +1398,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1203,6 
+1418,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1222,6 +1438,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1241,6 +1458,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1260,6 +1478,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1279,6 +1498,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1298,6 +1518,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1320,6 +1541,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1339,6 +1561,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1358,6 +1581,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1377,6 +1601,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1396,6 +1621,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1418,6 +1644,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni, 
xnn_init_f32_minmax_scalar_params, @@ -1437,6 +1664,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1456,6 +1684,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1478,6 +1707,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1497,6 +1727,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/5, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1519,6 +1750,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1535,6 +1767,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1551,6 +1784,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1567,6 +1801,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1583,6 +1818,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1599,6 +1835,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1615,6 +1852,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1631,6 +1869,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1650,6 +1889,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1669,6 +1909,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1688,6 +1929,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1707,6 +1949,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1726,6 +1969,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1745,6 +1989,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1767,6 +2012,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -1783,6 +2029,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1801,6 +2048,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x2__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-gemm-minmax-3.cc b/test/qd8-f32-qc8w-gemm-minmax-3.cc index 0d2a63fec0cd..67a4b1f57d7f 100644 --- a/test/qd8-f32-qc8w-gemm-minmax-3.cc +++ b/test/qd8-f32-qc8w-gemm-minmax-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector 
CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -307,6 +286,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -316,7 +296,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -326,12 +306,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -439,14 +413,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -472,14 +438,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -587,6 +545,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -601,6 +560,192 @@ std::vector CreateTests1( #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 
&& XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_6X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/6, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_2X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_6X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_5X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/5, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && 
XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_1X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_4X8C4__ASM_AARCH32_NEONDOT_CORTEX_A55, GemmTest, @@ -609,6 +754,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -631,6 +777,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -650,6 +797,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -669,6 +817,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -688,6 +837,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__neoni8mm, 
xnn_init_f32_minmax_scalar_params, @@ -707,6 +857,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -726,6 +877,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -745,6 +897,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -767,6 +920,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -786,6 +940,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -808,6 +963,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -827,6 +983,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__neondot_ld64, xnn_init_f32_minmax_scalar_params, @@ -846,6 +1003,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -868,6 +1026,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -887,6 +1046,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -909,6 +1069,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -931,6 +1092,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__sse41_ld128, 
xnn_init_f32_minmax_scalar_params, @@ -950,6 +1112,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -969,6 +1132,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -988,6 +1152,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1007,6 +1172,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1026,6 +1192,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1045,6 +1212,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1067,6 +1235,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, @@ -1089,6 +1258,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -1108,6 +1278,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -1130,6 +1301,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1149,6 +1321,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1168,6 +1341,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm, 
xnn_init_f32_minmax_scalar_params, @@ -1187,6 +1361,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1206,6 +1381,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1225,6 +1401,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1244,6 +1421,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1263,6 +1441,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1285,6 +1464,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1304,6 +1484,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1323,6 +1504,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1345,6 +1527,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1364,6 +1547,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1383,6 +1567,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1405,6 +1590,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1424,6 +1610,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/8, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1446,6 +1633,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1462,6 +1650,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1478,6 +1667,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1494,6 +1684,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1510,6 +1701,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1526,6 +1718,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1545,6 +1738,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1564,6 +1758,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1583,6 +1778,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1602,6 +1798,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1621,6 +1818,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1640,6 +1838,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1659,6 +1858,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1678,6 +1878,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1700,6 +1901,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1716,6 +1918,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8__wasm, xnn_init_f32_minmax_scalar_params, @@ -1734,6 +1937,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -1751,6 +1955,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-gemm-minmax-4.cc b/test/qd8-f32-qc8w-gemm-minmax-4.cc index c282bd8041ad..b7df81fb9dbc 100644 --- a/test/qd8-f32-qc8w-gemm-minmax-4.cc +++ b/test/qd8-f32-qc8w-gemm-minmax-4.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - 
.m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -307,6 +286,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -316,7 +296,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -326,12 +306,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -439,14 +413,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -472,14 +438,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -587,6 +545,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -606,6 +565,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -625,6 +585,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x16c4__avx512amx_prfm, xnn_init_f32_minmax_scalar_params, @@ -639,6 +600,152 @@ std::vector CreateTests1( #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64) +#if 
XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_7X32C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_7X16C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X64C4__ASM_AMD64_AVX512VNNI, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/64, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/true, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512VNNI; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_4X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); 
+ }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QD8_F32_QC8W_GEMM_MINMAX_3X16C4__ASM_AARCH64_NEONDOT_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/4, + /*adj_k_block=*/4, + /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane, + xnn_init_f32_minmax_scalar_params, + xnn_pack_qs8_gemm_goi_w); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + #if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 INSTANTIATE_TEST_SUITE_P( QD8_F32_QC8W_GEMM_MINMAX_1X32C8__NEONI8MM, GemmTest, @@ -647,6 +754,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -666,6 +774,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -685,6 +794,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -704,6 +814,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -723,6 +834,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -742,6 +854,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -764,6 +877,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c2s4__neon_mlal, xnn_init_f32_minmax_scalar_params, @@ -783,6 +897,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -805,6 +920,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__aarch64_neondot_ld128, 
xnn_init_f32_minmax_scalar_params, @@ -827,6 +943,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__neondot_ld64, xnn_init_f32_minmax_scalar_params, @@ -849,6 +966,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -868,6 +986,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -890,6 +1009,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -912,6 +1032,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -934,6 +1055,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -956,6 +1078,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -978,6 +1101,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -997,6 +1121,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -1019,6 +1144,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -1038,6 +1164,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld64, xnn_init_f32_minmax_scalar_params, @@ -1057,6 +1184,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld128, xnn_init_f32_minmax_scalar_params, @@ -1079,6 +1207,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1098,6 +1227,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -1117,6 +1247,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__sse41_ld128, xnn_init_f32_minmax_scalar_params, @@ -1136,6 +1267,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1155,6 +1287,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__sse41_ld128, xnn_init_f32_minmax_scalar_params, @@ -1174,6 +1307,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1193,6 +1327,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1212,6 +1347,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1234,6 +1370,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, @@ -1256,6 +1393,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -1275,6 +1413,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -1294,6 +1433,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -1313,6 +1453,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -1332,6 +1473,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -1354,6 +1496,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1373,6 +1516,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1392,6 +1536,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1411,6 +1556,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1430,6 +1576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1449,6 +1596,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1468,6 +1616,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1487,6 +1636,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1506,6 +1656,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1525,6 +1676,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1547,6 +1699,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1569,6 +1722,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1588,6 +1742,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1607,6 +1762,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1626,6 +1782,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1645,6 +1802,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1664,6 +1822,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1683,6 +1842,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1702,6 +1862,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1721,6 +1882,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1740,6 +1902,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1759,6 +1922,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/4, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1778,6 +1942,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u2_acc2, xnn_init_f32_minmax_scalar_params, @@ -1797,6 +1962,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1816,6 +1982,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1835,6 +2002,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1854,6 +2022,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1873,6 +2042,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/5, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1892,6 +2062,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1911,6 +2082,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/7, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1930,6 +2102,7 @@ std::vector CreateTests1( /*adj_k_block=*/32, /*mr=*/8, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u4_acc4, xnn_init_f32_minmax_scalar_params, @@ -1952,6 +2125,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1971,6 +2145,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/7, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1993,6 +2168,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, 
/*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -2009,6 +2185,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -2025,6 +2202,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -2041,6 +2219,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -2060,6 +2239,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2079,6 +2259,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2098,6 +2279,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2117,6 +2299,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -2136,6 +2319,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2155,6 +2339,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -2174,6 +2359,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2193,6 +2379,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -2212,6 +2399,7 @@ std::vector 
CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2231,6 +2419,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2250,6 +2439,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -2269,6 +2459,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -2288,6 +2479,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -2307,6 +2499,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -2326,6 +2519,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -2348,6 +2542,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -2364,6 +2559,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8__wasm, xnn_init_f32_minmax_scalar_params, @@ -2382,6 +2578,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -2399,6 +2596,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-gemm-minmax.cc b/test/qd8-f32-qc8w-gemm-minmax.cc index d59f9299c97f..dcb2e30980a8 100644 --- a/test/qd8-f32-qc8w-gemm-minmax.cc +++ b/test/qd8-f32-qc8w-gemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, 
std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -307,6 +286,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -316,7 +296,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -326,12 +306,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -439,14 +413,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1, 4) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1, 4) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -472,14 +438,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -587,6 +545,7 @@ std::vector 
CreateTests1(
      /*adj_k_block=*/64,
      /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
        tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__avx512amx,
                    xnn_init_f32_minmax_scalar_params,
@@ -606,6 +565,7 @@ std::vector CreateTests1(
      /*adj_k_block=*/64,
      /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
        tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x64c4__avx512amx,
                    xnn_init_f32_minmax_scalar_params,
@@ -625,6 +585,7 @@ std::vector CreateTests1(
      /*adj_k_block=*/64,
      /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
        tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_16x32c4__avx512amx,
                    xnn_init_f32_minmax_scalar_params,
@@ -639,6 +600,172 @@ std::vector CreateTests1(
 #endif // XNN_ENABLE_AVX512AMX && (XNN_ARCH_X86 || XNN_ARCH_X86_64)

+#if XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_4X32C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_8X32C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_11X32C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/11, /*nr=*/32, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_4X16C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_8X16C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_11X16C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/11, /*nr=*/16, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_3X64C4__ASM_AMD64_AVX512VNNI, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/3, /*nr=*/64, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/true,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_X86_AVX512VNNI;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+#endif // XNN_ENABLE_AVX512VNNI && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY
+
+
+#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
+  INSTANTIATE_TEST_SUITE_P(
+      QD8_F32_QC8W_GEMM_MINMAX_3X8C4__ASM_AARCH64_NEONDOT_LANE, GemmTest,
+      testing::ValuesIn(CreateTests1(
+          /*k_block=*/4,
+          /*adj_k_block=*/4,
+          /*mr=*/3, /*nr=*/8, /*kr=*/4, /*sr=*/1,
+          /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
+          [](GemmMicrokernelTester& tester) {
+            tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane,
+                        xnn_init_f32_minmax_scalar_params,
+                        xnn_pack_qs8_gemm_goi_w);
+          },
+          []() {
+            TEST_REQUIRES_ARM_NEON_DOT;
+          })),
+      [](const testing::TestParamInfo& info) {
+        return info.param.test_name;
+      });
+#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
+
+
 #if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64
  INSTANTIATE_TEST_SUITE_P(
      QD8_F32_QC8W_GEMM_MINMAX_1X16C8__NEONI8MM, GemmTest,
@@ -647,6 +774,7 @@ std::vector CreateTests1(
      /*adj_k_block=*/16,
      /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
        tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__neoni8mm,
                    xnn_init_f32_minmax_scalar_params,
@@ -666,6 +794,7 @@ std::vector CreateTests1(
      /*adj_k_block=*/16,
      /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
        tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__neoni8mm,
                    xnn_init_f32_minmax_scalar_params,
@@ -685,6 +814,7 @@ std::vector CreateTests1(
      /*adj_k_block=*/16,
      /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
        tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__neoni8mm,
                    xnn_init_f32_minmax_scalar_params,
@@ -704,6 +834,7 @@ std::vector CreateTests1(
      /*adj_k_block=*/16,
      /*mr=*/3, /*nr=*/32, /*kr=*/8, /*sr=*/1,
      /*is_igemm=*/false,
+      /*unsigned_inputs=*/false,
      [](GemmMicrokernelTester& tester) {
tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -723,6 +854,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/32, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -742,6 +874,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -764,6 +897,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__aarch64_neondot_ld128, xnn_init_f32_minmax_scalar_params, @@ -786,6 +920,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -805,6 +940,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -824,6 +960,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -843,6 +980,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -862,6 +1000,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -884,6 +1023,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -906,6 +1046,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -925,6 +1066,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__sse41_ld128, xnn_init_f32_minmax_scalar_params, @@ -944,6 +1086,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -963,6 +1106,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -982,6 +1126,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1001,6 +1146,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1023,6 +1169,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, @@ -1042,6 +1189,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, @@ -1064,6 +1212,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1083,6 +1232,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1102,6 +1252,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1121,6 +1272,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1140,6 +1292,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1159,6 +1312,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1178,6 +1332,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1197,6 +1352,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1219,6 +1375,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1238,6 +1395,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1257,6 +1415,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1276,6 +1435,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1295,6 +1455,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1314,6 +1475,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1333,6 +1495,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1355,6 +1518,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1374,6 +1538,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1393,6 +1558,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1412,6 +1578,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1431,6 +1598,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1450,6 +1618,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1472,6 +1641,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1491,6 +1661,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/6, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x4v__rvv, xnn_init_f32_minmax_scalar_params, @@ -1513,6 +1684,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1529,6 +1701,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1545,6 +1718,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1561,6 +1735,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1577,6 +1752,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1593,6 +1769,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1612,6 +1789,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1631,6 +1809,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x4c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1650,6 +1829,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c16__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1669,6 +1849,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1688,6 +1869,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1707,6 +1889,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1726,6 +1909,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1748,6 +1932,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -1766,6 +1951,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -1783,6 +1969,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-gemm-minmax.yaml b/test/qd8-f32-qc8w-gemm-minmax.yaml index 13e249757c39..3967610f9ea5 100644 --- a/test/qd8-f32-qc8w-gemm-minmax.yaml +++ b/test/qd8-f32-qc8w-gemm-minmax.yaml @@ -55,6 +55,176 @@ pack: xnn_pack_qs8_gemm_goi_w k-block: 64 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: 
True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True + +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True + +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- 
name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 + unsigned-inputs: True + +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 +- name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_lane + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_qs8_gemm_goi_w + k-block: 4 # AArch32 assembly - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55 init: xnn_init_f32_minmax_scalar_params @@ -501,342 +671,424 @@ init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params 
pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True # x86 AVX512VNNI - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_12x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f32_qc8w_gemm_minmax_ukernel_14x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True # AVXVNNI - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True # x86 AVXVNNI - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w 
k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u2_acc2 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x8c4__avxvnni_u4_acc4 init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_gemm_goi_w k-block: 32 + unsigned-inputs: True # RISC-V Vector - name: xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4v__rvv diff --git a/test/qd8-f32-qc8w-igemm-minmax-2.cc b/test/qd8-f32-qc8w-igemm-minmax-2.cc index 7e0644069e85..13d572925921 100644 --- a/test/qd8-f32-qc8w-igemm-minmax-2.cc +++ b/test/qd8-f32-qc8w-igemm-minmax-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - 
.m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx_prfm, xnn_init_f32_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x32c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x64c4__avx512amx_prfm, xnn_init_f32_minmax_scalar_params, @@ -391,6 +374,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -408,6 +392,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -425,6 +410,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4__scalar, xnn_init_f32_minmax_scalar_params, @@ -443,6 +429,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__aarch64_neondot_ld128, xnn_init_f32_minmax_scalar_params, @@ -465,6 +452,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -487,6 +475,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -509,6 +498,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__aarch64_neondot_ld128, xnn_init_f32_minmax_scalar_params, @@ -531,6 +521,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x32c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -553,6 +544,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c2s4__neon_mlal, xnn_init_f32_minmax_scalar_params, @@ -572,6 +564,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -591,6 +584,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -613,6 +607,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -632,6 +627,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -651,6 +647,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -670,6 +667,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -692,6 +690,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -711,6 +710,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -730,6 +730,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -752,6 +753,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -771,6 +773,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -790,6 +793,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -809,6 +813,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -828,6 +833,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -847,6 +853,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -866,6 +873,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -885,6 +893,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -904,6 +913,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -926,6 +936,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -945,6 +956,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -964,6 +976,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -983,6 +996,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1002,6 +1016,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1021,6 +1036,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1040,6 +1056,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1062,6 +1079,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1081,6 +1099,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1100,6 +1119,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1119,6 +1139,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1138,6 +1159,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1157,6 +1179,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1176,6 +1199,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1198,6 +1222,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -1214,6 +1239,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1233,6 +1259,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1252,6 +1279,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1271,6 +1299,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1290,6 +1319,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1309,6 +1339,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1328,6 +1359,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1350,6 +1382,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1366,6 +1399,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1382,6 +1416,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1398,6 +1433,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1414,6 +1450,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1430,6 +1467,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1446,6 +1484,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1465,6 +1504,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1484,6 +1524,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1503,6 +1544,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -1522,6 +1564,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1541,6 +1584,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1560,6 +1604,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__sse41_ld128, xnn_init_f32_minmax_scalar_params, @@ -1579,6 +1624,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1598,6 +1644,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1620,6 +1667,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, @@ -1639,6 +1687,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, 
/*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-igemm-minmax-3.cc b/test/qd8-f32-qc8w-igemm-minmax-3.cc index 35e41fac7cd7..76316c5a194e 100644 --- a/test/qd8-f32-qc8w-igemm-minmax-3.cc +++ b/test/qd8-f32-qc8w-igemm-minmax-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x32c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx_prfm, xnn_init_f32_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x64c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -391,6 +374,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + 
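Note on the recurring change in this test file and the ones below: CreateTests1() gains a trailing bool unsigned_inputs parameter, every INSTANTIATE_TEST_SUITE_P call site now passes it explicitly, and the tester builder forwards it via .unsigned_inputs(...); the strided_cn, n_gt_<nr>_strided_cn and n_div_<nr>_strided_cn test variants are dropped in the same hunk. A minimal sketch of the updated helper, assuming the std::vector/std::function template arguments and the tester header path, which the flattened diff text above does not preserve:

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

#include "test/gemm-microkernel-tester.h"  // assumed header for GemmMicrokernelTester / GemmTestParams

// Updated helper signature and builder call as they read after this diff.
std::vector<GemmTestParams> CreateTests1(
    size_t k_block, size_t adj_k_block,
    size_t mr, size_t nr, size_t kr, size_t sr,
    bool is_igemm,
    bool unsigned_inputs,  // new: threaded through to the tester below
    std::function<void(GemmMicrokernelTester& tester)> test_func,
    std::function<void()> isa_check = nullptr) {
  const GemmMicrokernelTester tester = GemmMicrokernelTester()
      .mr(mr).nr(nr).kr(kr).sr(sr)
      .unsigned_inputs(unsigned_inputs);  // new builder setter added by this change
  std::vector<GemmTestParams> gemm_tests;
  gemm_tests.reserve(42);
  // ...test-case construction as before, minus the strided_cn,
  // n_gt_<nr>_strided_cn and n_div_<nr>_strided_cn variants removed above...
  return gemm_tests;
}
```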
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8__scalar, xnn_init_f32_minmax_scalar_params, @@ -409,6 +393,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -431,6 +416,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -453,6 +439,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__neondot_ld64, xnn_init_f32_minmax_scalar_params, @@ -475,6 +462,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -497,6 +485,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -519,6 +508,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -541,6 +531,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -563,6 +554,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -582,6 +574,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x32c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -604,6 +597,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -626,6 +620,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -648,6 +643,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -670,6 +666,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -692,6 +689,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -711,6 +709,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x32c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -733,6 +732,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -752,6 +752,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -771,6 +772,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -793,6 +795,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x32c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -812,6 +815,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -831,6 +835,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x32c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -853,6 +858,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -872,6 +878,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -894,6 +901,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -913,6 +921,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -932,6 +941,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -951,6 +961,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -970,6 +981,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -989,6 +1001,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1008,6 +1021,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1027,6 +1041,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1046,6 +1061,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1065,6 +1081,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1084,6 +1101,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1103,6 +1121,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1122,6 +1141,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1141,6 +1161,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1160,6 +1181,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1182,6 +1204,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1201,6 +1224,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1220,6 +1244,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1242,6 +1267,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1261,6 +1287,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1280,6 +1307,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1302,6 +1330,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8__wasm, xnn_init_f32_minmax_scalar_params, @@ -1318,6 +1347,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8__wasm, xnn_init_f32_minmax_scalar_params, @@ -1337,6 +1367,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1356,6 +1387,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1375,6 +1407,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1394,6 +1427,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1413,6 +1447,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1432,6 +1467,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1451,6 +1487,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1473,6 +1510,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1489,6 +1527,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1505,6 +1544,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1521,6 +1561,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1537,6 +1578,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1553,6 +1595,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1569,6 +1612,7 @@ 
INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1585,6 +1629,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1601,6 +1646,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1617,6 +1663,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1636,6 +1683,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1655,6 +1703,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1674,6 +1723,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -1693,6 +1743,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1712,6 +1763,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1731,6 +1783,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__sse2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1750,6 +1803,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -1769,6 +1823,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__sse41_ld128, 
xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-igemm-minmax.cc b/test/qd8-f32-qc8w-igemm-minmax.cc index c1b7564180ae..cdcc266e73d6 100644 --- a/test/qd8-f32-qc8w-igemm-minmax.cc +++ b/test/qd8-f32-qc8w-igemm-minmax.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -313,6 +292,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -332,6 +312,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x16c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -351,6 +332,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_16x32c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -370,6 +352,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x64c4__avx512amx, xnn_init_f32_minmax_scalar_params, @@ -391,6 +374,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -408,6 
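For reference, a hypothetical call site as it reads after this change. Only the argument list, the kernel name, and the init function are taken from the hunks above; k_block and the pack function are inferred from the yaml entries further below, and the INSTANTIATE_TEST_SUITE_P plumbing around the call is elided:

```cpp
// Sketch only: unsigned_inputs is true for the avx512vnni/avx256vnni/avxvnni
// kernels and false for every other kernel touched in these files.
std::vector<GemmTestParams> avx512vnni_igemm_tests = CreateTests1(
    /*k_block=*/8,      // inferred from the yaml entry for this kernel
    /*adj_k_block=*/8,
    /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1,
    /*is_igemm=*/true,
    /*unsigned_inputs=*/true,
    [](GemmMicrokernelTester& tester) {
      tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni,
                  xnn_init_f32_minmax_scalar_params,
                  xnn_pack_qs8_conv_goki_w);  // pack fn taken from the yaml entry
    });
```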
+392,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x2__scalar, xnn_init_f32_minmax_scalar_params, @@ -425,6 +410,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8__scalar, xnn_init_f32_minmax_scalar_params, @@ -443,6 +429,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -462,6 +449,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -481,6 +469,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c2s4__neon_mlal, xnn_init_f32_minmax_scalar_params, @@ -503,6 +492,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__neondot_ld64, xnn_init_f32_minmax_scalar_params, @@ -522,6 +512,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -544,6 +535,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -566,6 +558,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -585,6 +578,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x16__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -607,6 +601,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -629,6 +624,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -648,6 +644,7 @@ 
INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -670,6 +667,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -692,6 +690,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8__neon_mlal_lane, xnn_init_f32_minmax_scalar_params, @@ -711,6 +710,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -733,6 +733,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55, xnn_init_f32_minmax_scalar_params, @@ -755,6 +756,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -777,6 +779,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ld128, xnn_init_f32_minmax_scalar_params, @@ -799,6 +802,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8__neon_mlal_lane_prfm, xnn_init_f32_minmax_scalar_params, @@ -821,6 +825,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -840,6 +845,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x16c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -859,6 +865,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/4, /*mr=*/8, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c4__neondot, xnn_init_f32_minmax_scalar_params, @@ -881,6 +888,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__neoni8mm, 
xnn_init_f32_minmax_scalar_params, @@ -900,6 +908,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -922,6 +931,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512skx, xnn_init_f32_minmax_scalar_params, @@ -941,6 +951,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -960,6 +971,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512skx_prfm, xnn_init_f32_minmax_scalar_params, @@ -982,6 +994,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1001,6 +1014,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1020,6 +1034,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1039,6 +1054,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1058,6 +1074,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1077,6 +1094,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni, xnn_init_f32_minmax_scalar_params, @@ -1096,6 +1114,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1115,6 +1134,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1134,6 +1154,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1153,6 +1174,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1175,6 +1197,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1194,6 +1217,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni, xnn_init_f32_minmax_scalar_params, @@ -1213,6 +1237,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1232,6 +1257,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1251,6 +1277,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1270,6 +1297,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1292,6 +1320,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni, xnn_init_f32_minmax_scalar_params, @@ -1311,6 +1340,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1330,6 +1360,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1349,6 +1380,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1368,6 +1400,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1387,6 +1420,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/true, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm, xnn_init_f32_minmax_scalar_params, @@ -1409,6 +1443,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x2__wasm, xnn_init_f32_minmax_scalar_params, @@ -1425,6 +1460,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1441,6 +1477,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/2, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4__wasm, xnn_init_f32_minmax_scalar_params, @@ -1460,6 +1497,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1479,6 +1517,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1498,6 +1537,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1517,6 +1557,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__wasmusdot, xnn_init_f32_minmax_scalar_params, @@ -1536,6 +1577,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__wasmusdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1555,6 +1597,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1574,6 +1617,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1593,6 +1637,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1612,6 +1657,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1631,6 +1677,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__wasmsdot_u2, xnn_init_f32_minmax_scalar_params, @@ -1650,6 +1697,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c16__wasmsdot, xnn_init_f32_minmax_scalar_params, @@ -1672,6 +1720,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1688,6 +1737,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c2__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1704,6 +1754,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1720,6 +1771,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c2__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1736,6 +1788,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1752,6 +1805,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1768,6 +1822,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__wasmsimd_dot16x2_ld128, xnn_init_f32_minmax_scalar_params, @@ -1787,6 +1842,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, 
/*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1806,6 +1862,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x4c8__sse41_ld128, xnn_init_f32_minmax_scalar_params, @@ -1825,6 +1882,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__avx_ld64, xnn_init_f32_minmax_scalar_params, @@ -1844,6 +1902,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1863,6 +1922,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1882,6 +1942,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x4c8__sse41_ld128, xnn_init_f32_minmax_scalar_params, @@ -1901,6 +1962,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x4c8__sse41_ld64, xnn_init_f32_minmax_scalar_params, @@ -1920,6 +1982,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__avx_ld128, xnn_init_f32_minmax_scalar_params, @@ -1939,6 +2002,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x4c8__sse2_ld64, xnn_init_f32_minmax_scalar_params, @@ -1958,6 +2022,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1977,6 +2042,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -1996,6 +2062,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avx2, xnn_init_f32_minmax_scalar_params, @@ -2018,6 +2085,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, @@ -2037,6 +2105,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256skx, xnn_init_f32_minmax_scalar_params, diff --git a/test/qd8-f32-qc8w-igemm-minmax.yaml b/test/qd8-f32-qc8w-igemm-minmax.yaml index aeef211db399..413acaec4c30 100644 --- a/test/qd8-f32-qc8w-igemm-minmax.yaml +++ b/test/qd8-f32-qc8w-igemm-minmax.yaml @@ -351,275 +351,341 @@ init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w 
k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c4__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 8 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x16c8__avx512vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True # x86 AVX256 VNNI - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: 
xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_9x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_12x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_14x8c8__avx256vnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True # x86 AVX VNNI - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni init: xnn_init_f32_minmax_scalar_params 
pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_2x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_3x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_6x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_7x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_8x8c8__avxvnni_prfm init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_qs8_conv_goki_w k-block: 16 + unsigned-inputs: True # WAsm - name: xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x2__wasm diff --git a/test/qp8-f32-qb4w-gemm-minmax.cc b/test/qp8-f32-qb4w-gemm-minmax.cc index bcad06285c9f..f33db9974ebe 100644 --- a/test/qp8-f32-qb4w-gemm-minmax.cc +++ b/test/qp8-f32-qb4w-gemm-minmax.cc @@ -36,6 +36,7 @@ std::vector CreateTests1( size_t mr, size_t nr, size_t kr, size_t sr, size_t mr_packed, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -44,7 +45,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed); + .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -56,16 +57,6 @@ std::vector CreateTests1( .b_zero_point(8) .bl(32) , test_func, isa_check)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_eq_" + kbs + "_strided_a", - tester.clone() - .m(mr).n(nr).k(k_block) - .a_stride(xnnpack::NextPrime(k_block + 1)) - .b_zero_point(8) - .bl(32) - , test_func, isa_check)); - } gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_subtile", tester.clone() @@ -116,6 +107,7 @@ std::vector CreateTests1( /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/2, /*mr_packed=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -138,6 +130,7 @@ std::vector CreateTests1( /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/2, /*mr_packed=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot, xnn_init_f32_qb4w_minmax_scalar_params, @@ -165,6 +158,7 @@ std::vector CreateTests1( /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm, xnn_init_f32_qb4w_minmax_scalar_params, @@ -187,6 +181,7 @@ std::vector CreateTests1( /*mr=*/8, /*nr=*/4, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2, xnn_init_f32_qb4w_minmax_scalar_params, @@ -209,6 +204,7 @@ std::vector CreateTests1( /*mr=*/16, /*nr=*/4, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qb4w_gemm_minmax_ukernel_16x4c16s2__neoni8mm_mstep4, xnn_init_f32_qb4w_minmax_scalar_params, diff --git a/test/qp8-f32-qc4w-gemm-minmax.cc b/test/qp8-f32-qc4w-gemm-minmax.cc index d6703c6170bc..08fd49d419d7 100644 --- a/test/qp8-f32-qc4w-gemm-minmax.cc +++ b/test/qp8-f32-qc4w-gemm-minmax.cc @@ -36,6 +36,7 @@ std::vector CreateTests1( size_t mr, size_t nr, size_t kr, size_t sr, size_t mr_packed, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -44,7 +45,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed); + .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -55,15 +56,6 @@ std::vector CreateTests1( .m(mr).n(nr).k(k_block) .b_zero_point(8) , test_func, isa_check)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_eq_" + kbs + "_strided_a", - tester.clone() - .m(mr).n(nr).k(k_block) - .a_stride(xnnpack::NextPrime(k_block + 1)) - .b_zero_point(8) - , test_func, isa_check)); - } gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_subtile", tester.clone() @@ -94,16 +86,6 @@ std::vector CreateTests1( .b_zero_point(8) , test_func, isa_check) .loop_k(1, adj_k_block - 1)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_lt_" + akbs + "_strided_a", - tester.clone() - .m(mr).n(nr) - .a_stride(xnnpack::NextPrime(adj_k_block + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_k(1, adj_k_block - 1)); - } gemm_tests.push_back(GemmTestParams( "k_lt_" + akbs + "_subtile", tester.clone() @@ -121,16 +103,6 @@ std::vector CreateTests1( .b_zero_point(8) , test_func, isa_check) .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); - if (is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_gt_" + akbs + "_strided_a", - tester.clone() - .m(mr).n(nr) - .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); - } gemm_tests.push_back(GemmTestParams( "k_gt_" + akbs + "_subtile", tester.clone() @@ -148,16 +120,6 @@ std::vector CreateTests1( .b_zero_point(8) , test_func, isa_check) .loop_k(adj_k_block + k_block, k_block * 5, k_block)); - if (is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_div_" + kbs + "_strided_a", - tester.clone() - .m(mr).n(nr) - .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_k(adj_k_block + k_block, k_block * 3, k_block)); - } gemm_tests.push_back(GemmTestParams( "k_div_" + kbs + "_subtile", tester.clone() @@ -176,17 +138,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - if (!is_igemm) { - 
gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_a", - tester.clone() - .m(mr) - .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block)); - } gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_subtile", tester.clone() @@ -204,17 +155,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_a", - tester.clone() - .m(mr) - .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) - .b_zero_point(8) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block)); - } gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_subtile", tester.clone() @@ -324,6 +264,7 @@ std::vector CreateTests1( /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/2, /*mr_packed=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_1x4c16s2__aarch64_neondot, xnn_init_f32_minmax_scalar_params, @@ -346,6 +287,7 @@ std::vector CreateTests1( /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/2, /*mr_packed=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_1x8c16s2__aarch64_neondot, xnn_init_f32_minmax_scalar_params, @@ -373,6 +315,7 @@ std::vector CreateTests1( /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_4x4c16s2__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -395,6 +338,7 @@ std::vector CreateTests1( /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_4x8c16s2__neoni8mm, xnn_init_f32_minmax_scalar_params, @@ -417,6 +361,7 @@ std::vector CreateTests1( /*mr=*/8, /*nr=*/4, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x4c16s2__neoni8mm_mstep2, xnn_init_f32_minmax_scalar_params, @@ -439,6 +384,7 @@ std::vector CreateTests1( /*mr=*/8, /*nr=*/8, /*kr=*/16, /*sr=*/2, /*mr_packed=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qp8_f32_qc4w_gemm_minmax_ukernel_8x8c16s2__neoni8mm_mstep2, xnn_init_f32_minmax_scalar_params, diff --git a/test/qp8-f32-qc8w-gemm-minmax.cc b/test/qp8-f32-qc8w-gemm-minmax.cc new file mode 100644 index 000000000000..6d781591abd0 --- /dev/null +++ b/test/qp8-f32-qc8w-gemm-minmax.cc @@ -0,0 +1,332 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// Auto-generated file. Do not edit! 
+// Specification: test/qp8-f32-qc8w-gemm-minmax.yaml +// Generator: tools/generate-gemm-test.py + +#include +#include +#include +#include + +#include +#include "xnnpack/allocator.h" +#include "xnnpack/common.h" +#include "xnnpack/gemm.h" +#include "xnnpack/igemm.h" +#include "xnnpack/isa-checks.h" +#include "xnnpack/microparams-init.h" +#include "xnnpack/pack.h" +#include "xnnpack/packw.h" +#include "xnnpack/ppmm.h" +#include "xnnpack/requantization.h" +#include "gemm-microkernel-tester.h" +#include "next_prime.h" + +namespace { + +std::vector CreateTests1( + size_t k_block, size_t adj_k_block, + size_t mr, size_t nr, size_t kr, size_t sr, + size_t mr_packed, + bool is_igemm, + bool unsigned_inputs, + std::function test_func, + std::function isa_check = nullptr) { + std::string kbs = std::to_string(k_block); + std::string kb2s = std::to_string(k_block * 2); + std::string akbs = std::to_string(adj_k_block); + std::string nrs = std::to_string(nr); + + const GemmMicrokernelTester tester = GemmMicrokernelTester() + .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed).unsigned_inputs(unsigned_inputs); + + std::vector gemm_tests; + gemm_tests.reserve(42); + + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs, + tester.clone() + .m(mr).n(nr).k(k_block) + , test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile", + tester.clone() + .k(k_block).iterations(1) + , test_func, isa_check) + .loop_n(1, nr) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile_m", + tester.clone() + .n(nr).k(k_block).iterations(1) + , test_func, isa_check) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile_n", + tester.clone() + .m(mr).k(k_block).iterations(1) + , test_func, isa_check) + .loop_n(1, nr)); + if (k_block > 1) { + gemm_tests.push_back(GemmTestParams( + "k_lt_" + akbs, + tester.clone() + .m(mr).n(nr) + , test_func, isa_check) + .loop_k(1, adj_k_block - 1)); + gemm_tests.push_back(GemmTestParams( + "k_lt_" + akbs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_k(1, adj_k_block - 1) + .loop_n(1, nr) + .loop_m(1, mr)); + } + gemm_tests.push_back(GemmTestParams( + "k_gt_" + akbs, + tester.clone() + .m(mr).n(nr) + , test_func, isa_check) + .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); + gemm_tests.push_back(GemmTestParams( + "k_gt_" + akbs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block) + .loop_n(1, nr) + .loop_m(1, mr)); + if (k_block > 1) { + gemm_tests.push_back(GemmTestParams( + "k_div_" + kbs, + tester.clone() + .m(mr).n(nr) + , test_func, isa_check) + .loop_k(adj_k_block + k_block, k_block * 5, k_block)); + gemm_tests.push_back(GemmTestParams( + "k_div_" + kbs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_k(adj_k_block + k_block, k_block * 5, k_block) + .loop_n(1, nr) + .loop_m(1, mr)); + } + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs, + tester.clone() + .m(mr) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block + 1) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs, + tester.clone() + .m(mr) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, 
k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block + 1) + .loop_m(1, mr)); + if (is_igemm) { + gemm_tests.push_back(GemmTestParams( + "small_kernel", + tester.clone() + .m(mr).n(nr).ks(3) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "small_kernel_subtile", + tester.clone() + .ks(3).iterations(1) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1) + .loop_n(1, nr) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs + "_small_kernel", + tester.clone() + .m(mr).ks(3) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs + "_small_kernel", + tester.clone() + .m(mr).ks(3) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block + 1)); + } + gemm_tests.push_back(GemmTestParams( + "strided_cm_subtile", + tester.clone() + .mr(mr).nr(nr).kr(kr).sr(sr) + .cm_stride(xnnpack::NextPrime(nr + 1)) + .iterations(1) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1) + .loop_n(1, nr) + .loop_m(1, mr)); + if (is_igemm) { + gemm_tests.push_back(GemmTestParams( + "a_offset", + tester.clone() + .m(mr).n(nr).ks(3) + .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1)) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "zero", + tester.clone() + .m(mr).n(nr).ks(3) + .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1)) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1) + .loop_zi(0, mr - 1)); + } + gemm_tests.push_back(GemmTestParams( + "qmin", + tester.clone() + .m(mr).n(nr).k(k_block).qmin(128) + , test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "qmax", + tester.clone() + .m(mr).n(nr).k(k_block).qmax(128) + , test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "strided_cm", + tester.clone() + .m(mr).n(nr).k(k_block) + .cm_stride(xnnpack::NextPrime(nr + 1)) + , test_func, isa_check)); + + return gemm_tests; +} + +} // namespace + + +#if XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QC8W_GEMM_MINMAX_1X4C4__AARCH64_NEONDOT, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/4, /*kr=*/4, /*sr=*/1, + /*mr_packed=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test_QP8F32QC8W(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QC8W_GEMM_MINMAX_1X4C8__AARCH64_NEONDOT, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, + /*mr_packed=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test_QP8F32QC8W(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const 
testing::TestParamInfo& info) { + return info.param.test_name; + }); + + + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QC8W_GEMM_MINMAX_16X4C4__AARCH64_NEONDOT_MSTEP4, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/16, /*nr=*/4, /*kr=*/4, /*sr=*/1, + /*mr_packed=*/4, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test_QP8F32QC8W(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_DOT; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_DOTPROD && XNN_ARCH_ARM64 + + +#if XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 + #if XNN_ENABLE_KLEIDIAI + INSTANTIATE_TEST_SUITE_P( + QP8_F32_QC8W_GEMM_MINMAX_16X4C8__NEONI8MM_MSTEP4, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/16, /*nr=*/4, /*kr=*/8, /*sr=*/1, + /*mr_packed=*/4, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test_QP8F32QC8W(xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4, + xnn_init_f32_minmax_scalar_params, + xnn_pack_kai_qs8_weights_and_biases, + xnn_packed_stride_kai_qs8_weights_and_biases); + }, + []() { + TEST_REQUIRES_ARM_NEON_I8MM; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + #endif // XNN_ENABLE_KLEIDIAI +#endif // XNN_ENABLE_ARM_I8MM && XNN_ARCH_ARM64 diff --git a/test/qp8-f32-qc8w-gemm-minmax.yaml b/test/qp8-f32-qc8w-gemm-minmax.yaml new file mode 100644 index 000000000000..7b8ff3896376 --- /dev/null +++ b/test/qp8-f32-qc8w-gemm-minmax.yaml @@ -0,0 +1,30 @@ +# Copyright 2023 Google LLC +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
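
The hunks above add a brand-new generated test, `test/qp8-f32-qc8w-gemm-minmax.cc`, together with its `test/qp8-f32-qc8w-gemm-minmax.yaml` spec. The file follows the same value-parameterized GoogleTest pattern as the other generated GEMM tests: `CreateTests1` builds a vector of named `GemmTestParams` from a cloned `GemmMicrokernelTester`, and `INSTANTIATE_TEST_SUITE_P` consumes that vector via `testing::ValuesIn`, naming each case from `param.test_name`. The following is a minimal, self-contained sketch of that pattern; the names `MiniGemmParams`, `MiniGemmTest`, and `CreateCases` are illustrative only and are not part of XNNPACK.

```cpp
// Simplified illustration of the value-parameterized test pattern used by the
// generated qp8-f32-qc8w GEMM tests (hypothetical names, not XNNPACK code).
#include <cstddef>
#include <string>
#include <vector>

#include <gtest/gtest.h>

struct MiniGemmParams {
  std::string test_name;  // becomes the GoogleTest case name
  size_t m, n, k;         // problem size exercised by this case
};

class MiniGemmTest : public testing::TestWithParam<MiniGemmParams> {};

TEST_P(MiniGemmTest, Runs) {
  const MiniGemmParams& p = GetParam();
  // A real test would run the microkernel and compare against a reference;
  // here we only check that the parameters are sane.
  EXPECT_GT(p.m * p.n * p.k, 0u);
}

// Analogous to CreateTests1: expand one (mr, nr, k_block) configuration into
// a handful of named cases.
static std::vector<MiniGemmParams> CreateCases(size_t mr, size_t nr, size_t k_block) {
  std::vector<MiniGemmParams> cases;
  cases.push_back({"k_eq_" + std::to_string(k_block), mr, nr, k_block});
  cases.push_back({"k_gt_" + std::to_string(k_block), mr, nr, k_block + 1});
  return cases;
}

INSTANTIATE_TEST_SUITE_P(
    MINI_GEMM, MiniGemmTest,
    testing::ValuesIn(CreateCases(/*mr=*/1, /*nr=*/4, /*k_block=*/8)),
    [](const testing::TestParamInfo<MiniGemmTest::ParamType>& info) {
      return info.param.test_name;
    });
```

Like the generated file, each instantiation is wrapped in ISA/feature guards (`XNN_ENABLE_KLEIDIAI`, `XNN_ARCH_ARM64`, `TEST_REQUIRES_ARM_NEON_DOT`) so the suite only registers on hardware that can run the kernel.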
+ +# Arm KleidiAI kernels +- name: xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c4__aarch64_neondot + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_kai_qs8_weights_and_biases + packed-stride: xnn_packed_stride_kai_qs8_weights_and_biases + k-block: 1 + cpp-check: XNN_ENABLE_KLEIDIAI +- name: xnn_qp8_f32_qc8w_gemm_minmax_ukernel_1x4c8__aarch64_neondot + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_kai_qs8_weights_and_biases + packed-stride: xnn_packed_stride_kai_qs8_weights_and_biases + k-block: 1 + cpp-check: XNN_ENABLE_KLEIDIAI +- name: xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c4__aarch64_neondot_mstep4 + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_kai_qs8_weights_and_biases + packed-stride: xnn_packed_stride_kai_qs8_weights_and_biases + k-block: 1 + cpp-check: XNN_ENABLE_KLEIDIAI +- name: xnn_qp8_f32_qc8w_gemm_minmax_ukernel_16x4c8__neoni8mm_mstep4 + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_kai_qs8_weights_and_biases + packed-stride: xnn_packed_stride_kai_qs8_weights_and_biases + k-block: 1 + cpp-check: XNN_ENABLE_KLEIDIAI diff --git a/test/qs8-qc8w-gemm-minmax-fp32-2.cc b/test/qs8-qc8w-gemm-minmax-fp32-2.cc index 45f04f0bb8b5..7f33561cc713 100644 --- a/test/qs8-qc8w-gemm-minmax-fp32-2.cc +++ b/test/qs8-qc8w-gemm-minmax-fp32-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -323,6 +303,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -343,6 +324,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -363,6 +345,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -383,6 +366,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -403,6 +387,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -426,6 +411,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_cortex_a35, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -466,6 +453,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_cortex_a35_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -489,6 +477,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -509,6 +498,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -529,6 +519,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c16__asm_aarch64_neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -549,6 +540,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, 
/*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -572,6 +564,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_cortex_a55, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -595,6 +588,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x1c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -615,6 +609,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x2c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -638,6 +633,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -658,6 +654,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -681,6 +678,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -701,6 +699,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -721,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -741,6 +741,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -764,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -784,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -807,6 +810,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -827,6 +831,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -847,6 +852,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -867,6 +873,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -887,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -907,6 +915,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2s4__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -927,6 +936,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -947,6 +957,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -967,6 +978,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -987,6 +999,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1007,6 +1020,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1027,6 +1041,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1050,6 +1065,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1070,6 +1086,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1093,6 +1110,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1113,6 +1131,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1133,6 +1152,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1153,6 +1173,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1173,6 +1194,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1193,6 +1215,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1213,6 +1236,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1233,6 +1257,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld64, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1253,6 +1278,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1273,6 +1299,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1293,6 +1320,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1313,6 +1341,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1333,6 +1362,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1353,6 +1383,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1373,6 +1404,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1393,6 +1425,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1413,6 +1446,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1433,6 +1467,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1453,6 +1488,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1473,6 +1509,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, 
/*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1493,6 +1530,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1513,6 +1551,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1533,6 +1572,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1553,6 +1593,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1573,6 +1614,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1593,6 +1635,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1616,6 +1659,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1639,6 +1683,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1659,6 +1704,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1679,6 +1725,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1702,6 +1749,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1722,6 +1770,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1742,6 +1791,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_10x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1762,6 +1812,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_14x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1782,6 +1833,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1802,6 +1854,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1822,6 +1875,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_12x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1842,6 +1896,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1862,6 +1917,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1882,6 +1938,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1902,6 +1959,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_9x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1922,6 +1980,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1945,6 +2004,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1965,6 +2025,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1985,6 +2046,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_10x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2008,6 +2070,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2028,6 +2091,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2048,6 +2112,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2068,6 +2133,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2091,6 +2157,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avxvnniint8_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2111,6 +2178,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avxvnniint8_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2134,6 +2202,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2151,6 +2220,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2168,6 +2238,7 @@ std::vector 
CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2185,6 +2256,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2202,6 +2274,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2219,6 +2292,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2236,6 +2310,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2256,6 +2331,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2276,6 +2352,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2296,6 +2373,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2316,6 +2394,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2336,6 +2415,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2356,6 +2436,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2376,6 +2457,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2396,6 +2478,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2416,6 +2499,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__wasmusdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2436,6 +2520,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2456,6 +2541,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2476,6 +2562,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__wasmusdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2496,6 +2583,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2516,6 +2604,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__wasmsdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2536,6 +2625,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2556,6 +2646,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2576,6 +2667,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2599,6 +2691,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) 
{ tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2616,6 +2709,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2633,6 +2727,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2652,6 +2747,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2670,6 +2766,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2688,6 +2785,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2706,6 +2804,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2724,6 +2823,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2742,6 +2842,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2760,6 +2861,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, diff --git a/test/qs8-qc8w-gemm-minmax-fp32-3.cc b/test/qs8-qc8w-gemm-minmax-fp32-3.cc index 15afb111969f..72bd467bbf50 100644 --- a/test/qs8-qc8w-gemm-minmax-fp32-3.cc +++ b/test/qs8-qc8w-gemm-minmax-fp32-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - 
.mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x64c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -326,6 +306,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -346,6 +327,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -366,6 +348,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__asm_aarch32_neonv8_mlal_lane_cortex_a35, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -386,6 +369,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -406,6 +390,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -426,6 +411,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_ld64_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -449,6 +435,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -469,6 +456,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c4__asm_aarch32_neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -492,6 +480,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -515,6 +504,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__asm_aarch64_neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -538,6 +528,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -558,6 +549,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -578,6 +570,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -601,6 +594,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -621,6 +615,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -641,6 +636,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -661,6 +657,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -681,6 +678,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -704,6 +702,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -724,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -744,6 +744,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -764,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -784,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2s4__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -804,6 +807,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -824,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -844,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -864,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -887,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -910,6 +918,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__aarch64_neondot_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -933,6 +942,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -953,6 +963,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -973,6 +984,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -993,6 +1005,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1013,6 +1026,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1033,6 +1047,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1053,6 +1068,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1073,6 +1089,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1093,6 +1110,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1113,6 +1131,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__neonv8_mlal_lane_prfm, 
xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1136,6 +1155,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1159,6 +1179,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1182,6 +1203,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1205,6 +1227,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1225,6 +1248,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1245,6 +1269,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1268,6 +1293,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1288,6 +1314,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1308,6 +1335,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1328,6 +1356,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1348,6 +1377,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1368,6 +1398,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1388,6 +1419,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1408,6 +1440,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1428,6 +1461,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1448,6 +1482,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1468,6 +1503,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1488,6 +1524,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1508,6 +1545,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1528,6 +1566,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1548,6 +1587,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1568,6 +1608,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1588,6 +1629,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1608,6 +1650,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1628,6 +1671,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1648,6 +1692,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1668,6 +1713,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1691,6 +1737,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1711,6 +1758,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1734,6 +1782,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1757,6 +1806,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1777,6 +1827,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1797,6 +1848,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_9x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1817,6 +1869,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ 
-1837,6 +1890,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1857,6 +1911,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_9x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1877,6 +1932,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_10x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1897,6 +1953,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_12x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1917,6 +1974,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_14x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1937,6 +1995,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1957,6 +2016,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_9x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1977,6 +2037,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_12x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1997,6 +2058,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_14x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2020,6 +2082,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2040,6 +2103,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2060,6 +2124,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_9x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2080,6 +2145,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2100,6 +2166,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_12x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2123,6 +2190,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2143,6 +2211,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2163,6 +2232,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2183,6 +2253,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2203,6 +2274,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2223,6 +2295,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2246,6 +2319,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2263,6 +2337,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2280,6 +2355,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2297,6 +2373,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2314,6 +2391,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2331,6 +2409,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2348,6 +2427,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2365,6 +2445,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2382,6 +2463,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2399,6 +2481,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2419,6 +2502,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2439,6 +2523,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2459,6 +2544,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2479,6 +2565,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2499,6 +2586,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2519,6 +2607,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2539,6 +2628,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2559,6 +2649,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2579,6 +2670,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2599,6 +2691,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2619,6 +2712,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2639,6 +2733,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__wasmusdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2659,6 +2754,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2679,6 +2775,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2699,6 +2796,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c4__wasmsdot_u2_acc2, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2719,6 +2817,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2742,6 +2841,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2759,6 +2859,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2778,6 +2879,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2796,6 +2898,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2814,6 +2917,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2832,6 +2936,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2850,6 +2955,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2868,6 +2974,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2886,6 +2993,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2904,6 +3012,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, diff --git a/test/qs8-qc8w-gemm-minmax-fp32.cc 
b/test/qs8-qc8w-gemm-minmax-fp32.cc
index e7e87e017036..d0b4b835bf4f 100644
--- a/test/qs8-qc8w-gemm-minmax-fp32.cc
+++ b/test/qs8-qc8w-gemm-minmax-fp32.cc
@@ -35,6 +35,7 @@ std::vector<GemmTestParams> CreateTests1(
     size_t k_block, size_t adj_k_block,
     size_t mr, size_t nr, size_t kr, size_t sr,
     bool is_igemm,
+    bool unsigned_inputs,
     std::function<void(GemmMicrokernelTester& tester)> test_func,
     std::function<void()> isa_check = nullptr) {
   std::string kbs = std::to_string(k_block);
@@ -43,7 +44,7 @@ std::vector<GemmTestParams> CreateTests1(
   std::string nrs = std::to_string(nr);
 
   const GemmMicrokernelTester tester = GemmMicrokernelTester()
-      .mr(mr).nr(nr).kr(kr).sr(sr);
+      .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs);
 
   std::vector<GemmTestParams> gemm_tests;
   gemm_tests.reserve(42);
@@ -53,12 +54,6 @@ std::vector<GemmTestParams> CreateTests1(
       tester.clone()
           .m(mr).n(nr).k(k_block)
     , test_func, isa_check));
-  gemm_tests.push_back(GemmTestParams(
-      "strided_cn",
-      tester.clone()
-          .m(mr).n(nr).k(k_block)
-          .cn_stride(xnnpack::NextPrime(nr + 1))
-    , test_func, isa_check));
   if (!is_igemm) {
     gemm_tests.push_back(GemmTestParams(
         "k_eq_" + kbs + "_strided_a",
@@ -166,14 +161,6 @@ std::vector<GemmTestParams> CreateTests1(
     , test_func, isa_check)
     .loop_n(nr + 1, nr * 2 - 1)
     .loop_k(1, k_block * 3, k_block + 1));
-  gemm_tests.push_back(GemmTestParams(
-      "n_gt_" + nrs + "_strided_cn",
-      tester.clone()
-          .m(mr)
-          .cn_stride(xnnpack::NextPrime(nr + 1))
-    , test_func, isa_check)
-    .loop_n(nr + 1, nr * 2 - 1)
-    .loop_k(1, k_block * 3, k_block + 1));
   if (!is_igemm) {
     gemm_tests.push_back(GemmTestParams(
         "n_gt_" + nrs + "_strided_a",
@@ -199,14 +186,6 @@ std::vector<GemmTestParams> CreateTests1(
     , test_func, isa_check)
     .loop_n(nr * 2, nr * 3, nr)
     .loop_k(1, k_block * 3, k_block + 1));
-  gemm_tests.push_back(GemmTestParams(
-      "n_div_" + nrs + "_strided_cn",
-      tester.clone()
-          .m(mr)
-          .cn_stride(xnnpack::NextPrime(nr + 1))
-    , test_func, isa_check)
-    .loop_n(nr * 2, nr * 3, nr)
-    .loop_k(1, k_block * 3, k_block + 1));
   if (!is_igemm) {
     gemm_tests.push_back(GemmTestParams(
         "n_div_" + nrs + "_strided_a",
@@ -303,6 +282,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c4__avx512amx,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -323,6 +303,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -343,6 +324,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__avx512amx,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -363,6 +345,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x32c4__avx512amx,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -383,6 +366,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -406,6 +390,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__asm_aarch32_neonv8_mlal_lane_cortex_a35_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -426,6 +411,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a7, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -466,6 +453,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -486,6 +474,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -506,6 +495,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -526,6 +516,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -549,6 +540,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -569,6 +561,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -592,6 +585,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__asm_aarch64_neondot_ld32, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -615,6 +609,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -635,6 +630,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mull, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -655,6 +651,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -675,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -698,6 +696,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_ld32, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -718,6 +717,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -738,6 +738,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -761,6 +762,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x2c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -781,6 +783,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x1c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -804,6 +807,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -824,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -844,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -864,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -884,6 +891,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -907,6 +915,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -927,6 +936,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -947,6 +957,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neon_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -967,6 +978,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -987,6 +999,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c2s4__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1007,6 +1020,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1027,6 +1041,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1047,6 +1062,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1067,6 +1083,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neonv8_mlal_ld1r, 
xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1087,6 +1104,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1107,6 +1125,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4s2__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1127,6 +1146,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c4s2__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1150,6 +1170,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__aarch64_neondot_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1173,6 +1194,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1196,6 +1218,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1219,6 +1242,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1239,6 +1263,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1259,6 +1284,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1279,6 +1305,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1299,6 +1326,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neon_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1319,6 +1347,7 @@ 
std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1339,6 +1368,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1359,6 +1389,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c2s4__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1379,6 +1410,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1399,6 +1431,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4s2__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1419,6 +1452,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c4s2__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1439,6 +1473,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1459,6 +1494,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1479,6 +1515,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1499,6 +1536,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1519,6 +1557,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1539,6 +1578,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1559,6 +1599,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1579,6 +1620,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1599,6 +1641,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1619,6 +1662,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1639,6 +1683,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1659,6 +1704,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1679,6 +1725,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1699,6 +1746,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1719,6 +1767,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1739,6 +1788,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1759,6 +1809,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1782,6 +1833,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1802,6 +1854,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1825,6 +1878,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1845,6 +1899,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1865,6 +1920,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1885,6 +1941,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1905,6 +1962,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1925,6 +1983,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1945,6 +2004,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1965,6 +2025,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1985,6 +2046,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld64, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2005,6 +2067,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2025,6 +2088,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2045,6 +2109,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2065,6 +2130,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2085,6 +2151,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2105,6 +2172,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2125,6 +2193,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2145,6 +2214,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2165,6 +2235,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2185,6 +2256,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2205,6 +2277,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2225,6 +2298,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, 
/*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2245,6 +2319,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2265,6 +2340,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2288,6 +2364,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2311,6 +2388,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2331,6 +2409,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2351,6 +2430,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2371,6 +2451,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2394,6 +2475,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2414,6 +2496,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_12x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2434,6 +2517,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2454,6 +2538,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_14x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2474,6 +2559,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2494,6 +2580,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_10x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2514,6 +2601,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2534,6 +2622,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2554,6 +2643,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_10x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2577,6 +2667,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_12x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2597,6 +2688,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_14x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2617,6 +2709,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2637,6 +2730,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2657,6 +2751,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2677,6 +2772,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_9x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2697,6 +2793,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_10x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2717,6 +2814,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_14x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2740,6 +2838,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2760,6 +2859,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2780,6 +2880,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2800,6 +2901,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2820,6 +2922,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2840,6 +2943,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2863,6 +2967,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2880,6 +2985,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2897,6 +3003,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2914,6 +3021,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2931,6 +3039,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2948,6 +3057,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2965,6 +3075,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2985,6 +3096,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3005,6 +3117,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3025,6 +3138,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3045,6 +3159,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3065,6 +3180,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3085,6 +3201,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3105,6 +3222,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3125,6 +3243,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3145,6 +3264,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3165,6 +3285,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3185,6 +3306,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3205,6 +3327,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3225,6 +3348,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3245,6 +3369,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3265,6 +3390,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3285,6 +3411,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c4__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3305,6 +3432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x16c4__wasmsdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3328,6 +3456,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ 
-3345,6 +3474,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3362,6 +3492,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3381,6 +3512,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3399,6 +3531,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3417,6 +3550,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3435,6 +3569,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3453,6 +3588,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3471,6 +3607,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3489,6 +3626,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3507,6 +3645,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3525,6 +3664,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, diff --git a/test/qs8-qc8w-igemm-minmax-fp32-2.cc b/test/qs8-qc8w-igemm-minmax-fp32-2.cc index c959044f1271..2df0e39913b5 100644 
--- a/test/qs8-qc8w-igemm-minmax-fp32-2.cc
+++ b/test/qs8-qc8w-igemm-minmax-fp32-2.cc
@@ -35,6 +35,7 @@ std::vector<GemmTestParams> CreateTests1(
     size_t k_block, size_t adj_k_block,
     size_t mr, size_t nr, size_t kr, size_t sr,
     bool is_igemm,
+    bool unsigned_inputs,
     std::function<void(GemmMicrokernelTester& tester)> test_func,
     std::function<void()> isa_check = nullptr) {
   std::string kbs = std::to_string(k_block);
@@ -43,7 +44,7 @@ std::vector<GemmTestParams> CreateTests1(
   std::string nrs = std::to_string(nr);
 
   const GemmMicrokernelTester tester = GemmMicrokernelTester()
-      .mr(mr).nr(nr).kr(kr).sr(sr);
+      .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs);
 
   std::vector<GemmTestParams> gemm_tests;
   gemm_tests.reserve(42);
@@ -53,12 +54,6 @@ std::vector<GemmTestParams> CreateTests1(
       tester.clone()
           .m(mr).n(nr).k(k_block)
     , test_func, isa_check));
-  gemm_tests.push_back(GemmTestParams(
-      "strided_cn",
-      tester.clone()
-          .m(mr).n(nr).k(k_block)
-          .cn_stride(xnnpack::NextPrime(nr + 1))
-    , test_func, isa_check));
   if (!is_igemm) {
     gemm_tests.push_back(GemmTestParams(
         "k_eq_" + kbs + "_strided_a",
@@ -166,14 +161,6 @@ std::vector<GemmTestParams> CreateTests1(
     , test_func, isa_check)
     .loop_n(nr + 1, nr * 2 - 1)
     .loop_k(1, k_block * 3, k_block + 1));
-  gemm_tests.push_back(GemmTestParams(
-      "n_gt_" + nrs + "_strided_cn",
-      tester.clone()
-          .m(mr)
-          .cn_stride(xnnpack::NextPrime(nr + 1))
-    , test_func, isa_check)
-    .loop_n(nr + 1, nr * 2 - 1)
-    .loop_k(1, k_block * 3, k_block + 1));
   if (!is_igemm) {
     gemm_tests.push_back(GemmTestParams(
         "n_gt_" + nrs + "_strided_a",
@@ -199,14 +186,6 @@ std::vector<GemmTestParams> CreateTests1(
     , test_func, isa_check)
     .loop_n(nr * 2, nr * 3, nr)
     .loop_k(1, k_block * 3, k_block + 1));
-  gemm_tests.push_back(GemmTestParams(
-      "n_div_" + nrs + "_strided_cn",
-      tester.clone()
-          .m(mr)
-          .cn_stride(xnnpack::NextPrime(nr + 1))
-    , test_func, isa_check)
-    .loop_n(nr * 2, nr * 3, nr)
-    .loop_k(1, k_block * 3, k_block + 1));
   if (!is_igemm) {
     gemm_tests.push_back(GemmTestParams(
         "n_div_" + nrs + "_strided_a",
@@ -303,6 +282,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512amx,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -323,6 +303,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx_prfm,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -343,6 +324,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/64,
           /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx_prfm,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
@@ -366,6 +348,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/8,
           /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params,
@@ -386,6 +369,7 @@ std::vector<GemmTestParams> CreateTests1(
           /*adj_k_block=*/8,
           /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1,
           /*is_igemm=*/false,
+          /*unsigned_inputs=*/false,
           [](GemmMicrokernelTester& tester) {
             tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm,
                         xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params,
@@ -406,6
+390,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__asm_aarch32_neonv8_mlal_lane_cortex_a35, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -426,6 +411,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__asm_aarch32_neonv8_mlal_lane_cortex_a35_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -466,6 +453,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -486,6 +474,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -506,6 +495,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -526,6 +516,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_ld64_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -549,6 +540,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -569,6 +561,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -589,6 +582,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -609,6 +603,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -632,6 +627,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -655,6 +651,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x1c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -678,6 +675,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -698,6 +696,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -721,6 +720,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -741,6 +741,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -761,6 +762,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -781,6 +783,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neon_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -801,6 +804,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -821,6 +825,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2s4__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -841,6 +846,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -861,6 +867,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4s2__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -881,6 +888,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -901,6 +909,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -921,6 +930,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -941,6 +951,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -961,6 +972,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2s4__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -981,6 +993,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1001,6 +1014,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1021,6 +1035,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1041,6 +1056,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1061,6 +1077,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__neon_mlal_lane, 
xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1081,6 +1098,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1101,6 +1119,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1121,6 +1140,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1141,6 +1161,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1161,6 +1182,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1184,6 +1206,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1207,6 +1230,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1230,6 +1254,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1250,6 +1275,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1270,6 +1296,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1290,6 +1317,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1310,6 +1338,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1330,6 +1359,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1350,6 +1380,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1370,6 +1401,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1390,6 +1422,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1410,6 +1443,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1430,6 +1464,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1450,6 +1485,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1470,6 +1506,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1490,6 +1527,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1510,6 +1548,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1530,6 +1569,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1550,6 +1590,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1570,6 +1611,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1590,6 +1632,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1610,6 +1653,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1630,6 +1674,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1653,6 +1698,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1673,6 +1719,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1693,6 +1740,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1713,6 +1761,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1736,6 +1785,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1756,6 +1806,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1776,6 +1827,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1796,6 +1848,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_9x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1816,6 +1869,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1836,6 +1890,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1856,6 +1911,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1876,6 +1932,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1896,6 +1953,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_14x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1916,6 +1974,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1936,6 +1995,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_10x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1956,6 +2016,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_14x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1979,6 +2040,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1999,6 
+2061,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2019,6 +2082,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2039,6 +2103,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_12x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2059,6 +2124,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_14x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2079,6 +2145,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2099,6 +2166,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2119,6 +2187,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2139,6 +2208,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_10x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2162,6 +2232,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2182,6 +2253,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2202,6 +2274,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2222,6 +2295,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2242,6 +2316,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2265,6 +2340,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2282,6 +2358,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2299,6 +2376,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2316,6 +2394,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2333,6 +2412,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2350,6 +2430,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2367,6 +2448,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2384,6 +2466,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2401,6 +2484,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2418,6 +2502,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2435,6 +2520,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2452,6 +2538,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2472,6 +2559,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2492,6 +2580,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2512,6 +2601,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2532,6 +2622,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2552,6 +2643,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2572,6 +2664,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2592,6 +2685,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2612,6 +2706,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2632,6 +2727,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2652,6 +2748,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2672,6 +2769,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2692,6 +2790,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2712,6 +2811,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2732,6 +2832,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2752,6 +2853,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2772,6 +2874,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2792,6 +2895,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2812,6 +2916,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2832,6 +2937,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__wasmsdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2855,6 +2961,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x2__wasm_fmagic, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2872,6 +2979,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2889,6 +2997,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2906,6 +3015,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2925,6 +3035,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2943,6 +3054,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2961,6 +3073,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2979,6 +3092,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2997,6 +3111,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3015,6 +3130,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3033,6 +3149,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, diff --git a/test/qs8-qc8w-igemm-minmax-fp32-3.cc b/test/qs8-qc8w-igemm-minmax-fp32-3.cc index 164429183c6e..d3afc906b725 100644 --- a/test/qs8-qc8w-igemm-minmax-fp32-3.cc +++ b/test/qs8-qc8w-igemm-minmax-fp32-3.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, 
std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -323,6 +303,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x32c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -346,6 +327,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -366,6 +348,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neonv8_mlal_lane_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -389,6 +372,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c4__asm_aarch32_neondot_cortex_a55, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -409,6 +393,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c4__asm_aarch32_neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -432,6 +417,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -452,6 +438,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -472,6 +459,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -492,6 +480,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__asm_aarch64_neon_mlal_cortex_a53_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -515,6 +504,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_cortex_a55, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -535,6 +525,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__asm_aarch64_neondot_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -558,6 +549,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x1c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -578,6 +570,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x2c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -601,6 +594,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -621,6 +615,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -641,6 +636,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -664,6 +660,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -684,6 +681,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -704,6 +702,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -724,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -744,6 +744,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -764,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -784,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4s2__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -807,6 +810,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__aarch64_neondot_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -830,6 +834,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -853,6 +858,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -876,6 +882,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16__neon_mlal_lane, 
xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -896,6 +903,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -916,6 +924,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -936,6 +945,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -956,6 +966,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -976,6 +987,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -996,6 +1008,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1016,6 +1029,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neon_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1036,6 +1050,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1056,6 +1071,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2s4__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1076,6 +1092,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1096,6 +1113,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1116,6 +1134,7 @@ 
std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4__neonv8_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1136,6 +1155,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4s2__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1156,6 +1176,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1176,6 +1197,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1196,6 +1218,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1216,6 +1239,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1239,6 +1263,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1262,6 +1287,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1282,6 +1308,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1302,6 +1329,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1325,6 +1353,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1345,6 +1374,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/8, 
/*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1368,6 +1398,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1388,6 +1419,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1408,6 +1440,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1428,6 +1461,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1448,6 +1482,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1468,6 +1503,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1488,6 +1524,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1508,6 +1545,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1528,6 +1566,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1548,6 +1587,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1568,6 +1608,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1588,6 +1629,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1608,6 +1650,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1628,6 +1671,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1648,6 +1692,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1668,6 +1713,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1688,6 +1734,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1708,6 +1755,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1728,6 +1776,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1748,6 +1797,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1768,6 +1818,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1788,6 +1839,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__avx_ld128, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1811,6 +1863,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1831,6 +1884,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1854,6 +1908,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1877,6 +1932,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_12x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1897,6 +1953,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_14x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1917,6 +1974,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1937,6 +1995,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/14, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_14x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1957,6 +2016,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1977,6 +2037,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_10x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1997,6 +2058,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_12x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2017,6 +2079,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2037,6 +2100,7 @@ 
std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2057,6 +2121,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_9x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2080,6 +2145,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/10, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_10x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2103,6 +2169,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2123,6 +2190,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2143,6 +2211,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2166,6 +2235,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avxvnniint8_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2186,6 +2256,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x8c8__avxvnniint8_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2209,6 +2280,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2226,6 +2298,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2243,6 +2316,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2260,6 +2334,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, 
/*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2277,6 +2352,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2294,6 +2370,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2314,6 +2391,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2334,6 +2412,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2354,6 +2433,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2374,6 +2454,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2394,6 +2475,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2414,6 +2496,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c16__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2434,6 +2517,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2454,6 +2538,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2474,6 +2559,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2494,6 +2580,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2514,6 +2601,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2534,6 +2622,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2554,6 +2643,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2574,6 +2664,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2594,6 +2685,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__wasmusdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2614,6 +2706,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__wasmusdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2634,6 +2727,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__wasmusdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2654,6 +2748,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2674,6 +2769,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__wasmsdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2694,6 +2790,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__wasmsdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2717,6 +2814,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2734,6 +2832,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2753,6 +2852,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2771,6 +2871,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2789,6 +2890,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2807,6 +2909,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2825,6 +2928,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2843,6 +2947,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2861,6 +2966,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2879,6 +2985,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2897,6 +3004,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4__scalar_lrintf, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, diff --git a/test/qs8-qc8w-igemm-minmax-fp32.cc b/test/qs8-qc8w-igemm-minmax-fp32.cc index 3a217467f016..92a12e6930bb 100644 --- a/test/qs8-qc8w-igemm-minmax-fp32.cc +++ b/test/qs8-qc8w-igemm-minmax-fp32.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -303,6 +282,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x16c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -323,6 +303,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x32c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -343,6 +324,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -363,6 +345,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x32c4__avx512amx_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -383,6 +366,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/1, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -403,6 +387,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/7, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x64c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -423,6 +408,7 @@ std::vector CreateTests1( /*adj_k_block=*/64, /*mr=*/16, /*nr=*/64, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -446,6 +432,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -469,6 +456,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__asm_aarch64_neon_mlal_cortex_a53, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -489,6 +477,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c16__asm_aarch64_neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -509,6 +498,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -532,6 +522,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x2c4__armsimd32, xnn_init_qs8_qc8w_conv_minmax_fp32_armsimd32_params, @@ -555,6 +546,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -575,6 +567,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -595,6 +588,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -615,6 +609,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -635,6 +630,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -655,6 +651,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -675,6 +672,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x8c8__neoni8mm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -698,6 +696,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -718,6 +717,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2__neonv8_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -738,6 +738,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c2s4__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -758,6 +759,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -778,6 +780,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -801,6 +804,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -824,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neonv8_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -844,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c4__neonv8_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, 
@@ -864,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -887,6 +894,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -910,6 +918,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__aarch64_neondot_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -933,6 +942,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__neondot_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -956,6 +966,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -976,6 +987,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -996,6 +1008,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c2__neonv8_mlal_ld4r, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1016,6 +1029,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4__neon_mlal_dup, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1036,6 +1050,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4__neon_mlal_ld1r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1056,6 +1071,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4__neon_mlal_ld2r, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1076,6 +1092,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, /*kr=*/4, /*sr=*/2, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c4s2__neonv8_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1096,6 +1113,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/2, /*nr=*/8, 
/*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__neon_mlal, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1116,6 +1134,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1136,6 +1155,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1156,6 +1176,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1176,6 +1197,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1196,6 +1218,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1216,6 +1239,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1236,6 +1260,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8__neonv8_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1256,6 +1281,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16__neonv8_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1279,6 +1305,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1302,6 +1329,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1322,6 +1350,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16__neon_mlal_lane, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1342,6 +1371,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x16__neon_mlal_lane_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_neon_params, @@ -1365,6 +1395,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__neondot, xnn_init_qs8_qc8w_conv_minmax_fp32_neonv8_params, @@ -1388,6 +1419,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1408,6 +1440,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1428,6 +1461,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1448,6 +1482,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1468,6 +1503,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1488,6 +1524,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1508,6 +1545,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1528,6 +1566,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1548,6 +1587,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__sse2_ld128, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1568,6 +1608,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1588,6 +1629,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1608,6 +1650,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1628,6 +1671,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1648,6 +1692,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1668,6 +1713,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2s4__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1688,6 +1734,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1708,6 +1755,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1728,6 +1776,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1748,6 +1797,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1768,6 +1818,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1788,6 +1839,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1808,6 +1860,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__sse41_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1828,6 +1881,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__avx_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1848,6 +1902,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__sse41_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1868,6 +1923,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1888,6 +1944,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c8__avx_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1908,6 +1965,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__avx2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1931,6 +1989,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1951,6 +2010,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x8c8__avx256skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1974,6 +2034,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -1994,6 +2055,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2014,6 +2076,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512skx, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2037,6 +2100,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2057,6 +2121,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2077,6 +2142,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_10x16c4__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2097,6 +2163,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2117,6 +2184,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/9, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_9x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2137,6 +2205,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/10, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_10x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2157,6 +2226,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/12, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_12x16c4__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2177,6 +2247,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2197,6 +2268,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2217,6 +2289,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_9x16c8__avx512vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2237,6 +2310,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2257,6 +2331,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_12x16c8__avx512vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2280,6 +2355,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2300,6 +2376,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_9x8c8__avx256vnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2320,6 +2397,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2340,6 +2418,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/9, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_9x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2360,6 +2439,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/12, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_12x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2380,6 +2460,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/14, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_14x8c8__avx256vnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2403,6 +2484,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2423,6 +2505,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2443,6 +2526,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/5, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_5x8c8__avxvnni, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2463,6 +2547,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avxvnni_prfm, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2483,6 +2568,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2503,6 +2589,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/6, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_6x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2523,6 +2610,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/7, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2543,6 +2631,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/8, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_8x8c8__avxvnni_prfm, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2566,6 +2655,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2583,6 +2673,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2600,6 +2691,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld64, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2617,6 +2709,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2634,6 +2727,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2651,6 +2745,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld128, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2671,6 +2766,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/4, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x4c16__wasmsdot, 
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2691,6 +2787,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/3, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2711,6 +2808,7 @@ std::vector CreateTests1( /*adj_k_block=*/16, /*mr=*/4, /*nr=*/8, /*kr=*/16, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c16__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2731,6 +2829,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2751,6 +2850,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2771,6 +2871,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2791,6 +2892,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__wasmusdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2811,6 +2913,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__wasmusdot_u2_acc2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2831,6 +2934,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x16c4__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2851,6 +2955,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__wasmsdot, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2871,6 +2976,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x16c4__wasmsdot_u2, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2894,6 +3000,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x2__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2911,6 +3018,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x4__wasm_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2930,6 +3038,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2948,6 +3057,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2966,6 +3076,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -2984,6 +3095,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x4__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3002,6 +3114,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_2x2__scalar_lrintf, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3020,6 +3133,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_3x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3038,6 +3152,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x2__scalar_fmagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, @@ -3056,6 +3171,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x2__scalar_imagic, xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params, diff --git a/test/qu8-gemm-minmax-fp32-2.cc b/test/qu8-gemm-minmax-fp32-2.cc index a1c5953f6029..dae5bfe4d981 100644 --- a/test/qu8-gemm-minmax-fp32-2.cc +++ b/test/qu8-gemm-minmax-fp32-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector 
CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x1c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x1c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -382,6 +363,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -405,6 +387,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -425,6 +408,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -445,6 +429,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -465,6 +450,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -485,6 +471,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -505,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -525,6 +513,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -545,6 +534,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -565,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -585,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -605,6 +597,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -625,6 +618,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -645,6 +639,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -665,6 +660,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -685,6 +681,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -705,6 +702,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -725,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -745,6 +744,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -765,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -785,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -805,6 +807,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -825,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -845,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -865,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -885,6 +891,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -905,6 +912,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -925,6 +933,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -945,6 +954,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -965,6 +975,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, 
/*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -985,6 +996,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1005,6 +1017,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1025,6 +1038,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1045,6 +1059,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1065,6 +1080,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1088,6 +1104,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1108,6 +1125,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1131,6 +1149,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1148,6 +1167,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1165,6 +1185,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1182,6 +1203,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ 
-1199,6 +1221,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1216,6 +1239,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1233,6 +1257,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1250,6 +1275,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1270,6 +1296,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1287,6 +1314,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1304,6 +1332,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1321,6 +1350,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1338,6 +1368,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1355,6 +1386,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1374,6 +1406,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1392,6 +1425,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1410,6 +1444,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1428,6 +1463,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1446,6 +1482,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1464,6 +1501,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1482,6 +1520,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1500,6 +1539,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1518,6 +1558,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1536,6 +1577,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1554,6 +1596,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1572,6 +1615,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, diff --git a/test/qu8-gemm-minmax-fp32.cc b/test/qu8-gemm-minmax-fp32.cc index 00e53b38f021..e1c62e393e5d 100644 --- a/test/qu8-gemm-minmax-fp32.cc +++ b/test/qu8-gemm-minmax-fp32.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs 
= std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x2c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x2c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -382,6 +363,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -402,6 +384,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -422,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16__neonv8_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neonv8_params, @@ -442,6 +426,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -462,6 +447,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16__neonv8_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neonv8_params, @@ -485,6 +471,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -505,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -525,6 +513,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -545,6 +534,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -565,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -585,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -605,6 +597,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -625,6 +618,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -645,6 +639,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -665,6 +660,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -685,6 +681,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -705,6 +702,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -725,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -745,6 +744,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -765,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -785,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -805,6 +807,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -825,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -845,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -865,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -885,6 +891,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -905,6 +912,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -925,6 +933,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -945,6 +954,7 @@ std::vector CreateTests1( 
/*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -965,6 +975,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -985,6 +996,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1005,6 +1017,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1025,6 +1038,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1045,6 +1059,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1065,6 +1080,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1085,6 +1101,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1105,6 +1122,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1125,6 +1143,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1145,6 +1164,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1165,6 +1185,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ 
-1185,6 +1206,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1208,6 +1230,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x8c8__avx256skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1231,6 +1254,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1251,6 +1275,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1271,6 +1296,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1291,6 +1317,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1311,6 +1338,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1331,6 +1359,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1354,6 +1383,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1371,6 +1401,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1388,6 +1419,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1405,6 +1437,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1422,6 +1455,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1439,6 +1473,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1456,6 +1491,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1473,6 +1509,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1490,6 +1527,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1507,6 +1545,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1524,6 +1563,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1541,6 +1581,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1558,6 +1599,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1575,6 +1617,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1592,6 +1635,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld128, 
xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1609,6 +1653,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1629,6 +1674,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1646,6 +1692,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1665,6 +1712,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1683,6 +1731,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1701,6 +1750,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1719,6 +1769,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1737,6 +1788,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1755,6 +1807,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1773,6 +1826,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1791,6 +1845,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1809,6 +1864,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1827,6 +1883,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1845,6 +1902,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1863,6 +1921,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, diff --git a/test/qu8-gemm-minmax-rndnu-2.cc b/test/qu8-gemm-minmax-rndnu-2.cc index 11e031265d33..3590463ba692 100644 --- a/test/qu8-gemm-minmax-rndnu-2.cc +++ b/test/qu8-gemm-minmax-rndnu-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a7, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -382,6 +363,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a75, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -402,6 +384,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -422,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -445,6 +429,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -465,6 +450,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -485,6 +471,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -505,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_6x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -527,6 +515,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x2__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, @@ -545,6 +534,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x4__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, @@ -563,6 +553,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x2__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, @@ -581,6 +572,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x4__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, diff --git 
a/test/qu8-gemm-minmax-rndnu.cc b/test/qu8-gemm-minmax-rndnu.cc index 8ed4458de22c..89c74108a22d 100644 --- a/test/qu8-gemm-minmax-rndnu.cc +++ b/test/qu8-gemm-minmax-rndnu.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -379,6 +360,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -399,6 +381,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -419,6 +402,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -439,6 +423,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -462,6 +447,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -482,6 +468,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -502,6 +489,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu16_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, xnn_init_qu8_conv_minmax_rndnu16_scalar_params, @@ -522,6 +510,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a75_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -545,6 +534,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -565,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu16_ukernel_1x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu16_scalar_params, @@ -585,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -605,6 +597,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -625,6 +618,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -645,6 +639,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x16__neon_mlal_lane, 
xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -665,6 +660,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_6x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -687,6 +683,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x2__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, @@ -705,6 +702,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x4__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, @@ -723,6 +721,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x2__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, @@ -741,6 +740,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x4__scalar, xnn_init_qu8_conv_minmax_rndnu_scalar_params, diff --git a/test/qu8-gemm-minmax-rndnu16.cc b/test/qu8-gemm-minmax-rndnu16.cc index 35679946dcbd..b1f9b06705bc 100644 --- a/test/qu8-gemm-minmax-rndnu16.cc +++ b/test/qu8-gemm-minmax-rndnu16.cc @@ -32,8 +32,10 @@ namespace { std::vector CreateTests1( - size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, - size_t sr, bool is_igemm, + size_t k_block, size_t adj_k_block, + size_t mr, size_t nr, size_t kr, size_t sr, + bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -41,250 +43,975 @@ std::vector CreateTests1( std::string akbs = std::to_string(adj_k_block); std::string nrs = std::to_string(nr); - const GemmMicrokernelTester tester = - GemmMicrokernelTester().mr(mr).nr(nr).kr(kr).sr(sr); + const GemmMicrokernelTester tester = GemmMicrokernelTester() + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); - gemm_tests.push_back(GemmTestParams("k_eq_" + kbs, - tester.clone().m(mr).n(nr).k(k_block), - test_func, isa_check)); - gemm_tests.push_back( - GemmTestParams("strided_cn", - tester.clone().m(mr).n(nr).k(k_block).cn_stride( - xnnpack::NextPrime(nr + 1)), - test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs, + tester.clone() + .m(mr).n(nr).k(k_block) + , test_func, isa_check)); if (!is_igemm) { - gemm_tests.push_back( - GemmTestParams("k_eq_" + kbs + "_strided_a", - tester.clone().m(mr).n(nr).k(k_block).a_stride( - xnnpack::NextPrime(k_block + 1)), - test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_strided_a", + tester.clone() + .m(mr).n(nr).k(k_block) + .a_stride(xnnpack::NextPrime(k_block + 1)) + , test_func, isa_check)); } - gemm_tests.push_back(GemmTestParams("k_eq_" + kbs + "_subtile", - tester.clone().k(k_block).iterations(1), - test_func, isa_check) - .loop_n(1, nr) - .loop_m(1, mr)); - gemm_tests.push_back( - GemmTestParams("k_eq_" + kbs + "_subtile_m", - 
tester.clone().n(nr).k(k_block).iterations(1), test_func, - isa_check) - .loop_m(1, mr)); - gemm_tests.push_back( - GemmTestParams("k_eq_" + kbs + "_subtile_n", - tester.clone().m(mr).k(k_block).iterations(1), test_func, - isa_check) - .loop_n(1, nr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile", + tester.clone() + .k(k_block).iterations(1) + , test_func, isa_check) + .loop_n(1, nr) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile_m", + tester.clone() + .n(nr).k(k_block).iterations(1) + , test_func, isa_check) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_subtile_n", + tester.clone() + .m(mr).k(k_block).iterations(1) + , test_func, isa_check) + .loop_n(1, nr)); if (k_block > 1) { - gemm_tests.push_back(GemmTestParams("k_lt_" + akbs, - tester.clone().m(mr).n(nr), test_func, - isa_check) - .loop_k(1, adj_k_block - 1)); + gemm_tests.push_back(GemmTestParams( + "k_lt_" + akbs, + tester.clone() + .m(mr).n(nr) + , test_func, isa_check) + .loop_k(1, adj_k_block - 1)); if (!is_igemm) { - gemm_tests.push_back( - GemmTestParams("k_lt_" + akbs + "_strided_a", - tester.clone().m(mr).n(nr).a_stride( - xnnpack::NextPrime(adj_k_block + 1)), - test_func, isa_check) - .loop_k(1, adj_k_block - 1)); + gemm_tests.push_back(GemmTestParams( + "k_lt_" + akbs + "_strided_a", + tester.clone() + .m(mr).n(nr) + .a_stride(xnnpack::NextPrime(adj_k_block + 1)) + , test_func, isa_check) + .loop_k(1, adj_k_block - 1)); } - gemm_tests.push_back(GemmTestParams("k_lt_" + akbs + "_subtile", - tester.clone().iterations(1), test_func, - isa_check) - .loop_k(1, adj_k_block - 1) - .loop_n(1, nr) - .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_lt_" + akbs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_k(1, adj_k_block - 1) + .loop_n(1, nr) + .loop_m(1, mr)); } - gemm_tests.push_back( - GemmTestParams("k_gt_" + akbs, tester.clone().m(mr).n(nr), test_func, - isa_check) - .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); + gemm_tests.push_back(GemmTestParams( + "k_gt_" + akbs, + tester.clone() + .m(mr).n(nr) + , test_func, isa_check) + .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); if (is_igemm) { - gemm_tests.push_back( - GemmTestParams("k_gt_" + akbs + "_strided_a", - tester.clone().m(mr).n(nr).a_stride( - xnnpack::NextPrime(adj_k_block * 2 + 1)), - test_func, isa_check) - .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); + gemm_tests.push_back(GemmTestParams( + "k_gt_" + akbs + "_strided_a", + tester.clone() + .m(mr).n(nr) + .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1)) + , test_func, isa_check) + .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); } - gemm_tests.push_back( - GemmTestParams("k_gt_" + akbs + "_subtile", tester.clone().iterations(1), - test_func, isa_check) - .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block) - .loop_n(1, nr) - .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_gt_" + akbs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block) + .loop_n(1, nr) + .loop_m(1, mr)); if (k_block > 1) { - gemm_tests.push_back( - GemmTestParams("k_div_" + kbs, tester.clone().m(mr).n(nr), test_func, - isa_check) - .loop_k(adj_k_block + k_block, k_block * 5, k_block)); + gemm_tests.push_back(GemmTestParams( + "k_div_" + kbs, + tester.clone() + .m(mr).n(nr) + , test_func, isa_check) + .loop_k(adj_k_block + k_block, k_block * 5, k_block)); if (is_igemm) { - 
gemm_tests.push_back( - GemmTestParams("k_div_" + kbs + "_strided_a", - tester.clone().m(mr).n(nr).a_stride( - xnnpack::NextPrime(k_block * 3 + 1)), - test_func, isa_check) - .loop_k(adj_k_block + k_block, k_block * 3, k_block)); + gemm_tests.push_back(GemmTestParams( + "k_div_" + kbs + "_strided_a", + tester.clone() + .m(mr).n(nr) + .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) + , test_func, isa_check) + .loop_k(adj_k_block + k_block, k_block * 3, k_block)); } - gemm_tests.push_back( - GemmTestParams("k_div_" + kbs + "_subtile", - tester.clone().iterations(1), test_func, isa_check) - .loop_k(adj_k_block + k_block, k_block * 5, k_block) - .loop_n(1, nr) - .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "k_div_" + kbs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_k(adj_k_block + k_block, k_block * 5, k_block) + .loop_n(1, nr) + .loop_m(1, mr)); } - gemm_tests.push_back( - GemmTestParams("n_gt_" + nrs, tester.clone().m(mr), test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back( - GemmTestParams("n_gt_" + nrs + "_strided_cn", - tester.clone().m(mr).cn_stride(xnnpack::NextPrime(nr + 1)), - test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs, + tester.clone() + .m(mr) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { - gemm_tests.push_back( - GemmTestParams( - "n_gt_" + nrs + "_strided_a", - tester.clone().m(mr).a_stride(xnnpack::NextPrime(k_block * 3 + 1)), - test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block)); + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs + "_strided_a", + tester.clone() + .m(mr) + .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block)); } - gemm_tests.push_back(GemmTestParams("n_gt_" + nrs + "_subtile", - tester.clone().iterations(1), test_func, - isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1) - .loop_m(1, mr)); - gemm_tests.push_back( - GemmTestParams("n_div_" + nrs, tester.clone().m(mr), test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back( - GemmTestParams("n_div_" + nrs + "_strided_cn", - tester.clone().m(mr).cn_stride(xnnpack::NextPrime(nr + 1)), - test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block + 1) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs, + tester.clone() + .m(mr) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { - gemm_tests.push_back( - GemmTestParams( - "n_div_" + nrs + "_strided_a", - tester.clone().m(mr).a_stride(xnnpack::NextPrime(k_block * 3 + 1)), - test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs + "_strided_a", + tester.clone() + .m(mr) + .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block)); } - gemm_tests.push_back(GemmTestParams("n_div_" + nrs + "_subtile", - 
tester.clone().iterations(1), test_func, - isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1) - .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs + "_subtile", + tester.clone() + .iterations(1) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block + 1) + .loop_m(1, mr)); if (is_igemm) { - gemm_tests.push_back(GemmTestParams("small_kernel", - tester.clone().m(mr).n(nr).ks(3), - test_func, isa_check) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams("small_kernel_subtile", - tester.clone().ks(3).iterations(1), - test_func, isa_check) - .loop_k(1, k_block * 3, k_block + 1) - .loop_n(1, nr) - .loop_m(1, mr)); - gemm_tests.push_back(GemmTestParams("n_gt_" + nrs + "_small_kernel", - tester.clone().m(mr).ks(3), test_func, - isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams("n_div_" + nrs + "_small_kernel", - tester.clone().m(mr).ks(3), test_func, - isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "small_kernel", + tester.clone() + .m(mr).n(nr).ks(3) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "small_kernel_subtile", + tester.clone() + .ks(3).iterations(1) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1) + .loop_n(1, nr) + .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs + "_small_kernel", + tester.clone() + .m(mr).ks(3) + , test_func, isa_check) + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs + "_small_kernel", + tester.clone() + .m(mr).ks(3) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block + 1)); } - gemm_tests.push_back(GemmTestParams("strided_cm_subtile", - tester.clone() - .mr(mr) - .nr(nr) - .kr(kr) - .sr(sr) - .cm_stride(xnnpack::NextPrime(nr + 1)) - .iterations(1), - test_func, isa_check) - .loop_k(1, k_block * 3, k_block + 1) - .loop_n(1, nr) - .loop_m(1, mr)); + gemm_tests.push_back(GemmTestParams( + "strided_cm_subtile", + tester.clone() + .mr(mr).nr(nr).kr(kr).sr(sr) + .cm_stride(xnnpack::NextPrime(nr + 1)) + .iterations(1) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1) + .loop_n(1, nr) + .loop_m(1, mr)); if (is_igemm) { - gemm_tests.push_back( - GemmTestParams("a_offset", - tester.clone().m(mr).n(nr).ks(3).a_offset( - xnnpack::NextPrime(mr * k_block * 3 + 1)), - test_func, isa_check) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back( - GemmTestParams("zero", - tester.clone().m(mr).n(nr).ks(3).a_offset( - xnnpack::NextPrime(mr * k_block * 3 + 1)), - test_func, isa_check) - .loop_k(1, k_block * 3, k_block + 1) - .loop_zi(0, mr - 1)); + gemm_tests.push_back(GemmTestParams( + "a_offset", + tester.clone() + .m(mr).n(nr).ks(3) + .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1)) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "zero", + tester.clone() + .m(mr).n(nr).ks(3) + .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1)) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1) + .loop_zi(0, mr - 1)); } - gemm_tests.push_back( - GemmTestParams("qmin", tester.clone().m(mr).n(nr).k(k_block).qmin(128), - test_func, isa_check)); - gemm_tests.push_back( - GemmTestParams("qmax", 
tester.clone().m(mr).n(nr).k(k_block).qmax(128), - test_func, isa_check)); - gemm_tests.push_back( - GemmTestParams("strided_cm", - tester.clone().m(mr).n(nr).k(k_block).cm_stride( - xnnpack::NextPrime(nr + 1)), - test_func, isa_check)); - gemm_tests.push_back( - GemmTestParams("no_a_zero_point", - tester.clone().m(mr).n(nr).a_zero_point(0), test_func, - isa_check) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back( - GemmTestParams("no_b_zero_point", - tester.clone().m(mr).n(nr).b_zero_point(0), test_func, - isa_check) - .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams("b_zero_point", - tester.clone().m(mr).n(nr).k(k_block), - test_func, isa_check) - .loop_bzp(0, 255)); - gemm_tests.push_back( - GemmTestParams("no_zero_point", - tester.clone().m(mr).n(nr).a_zero_point(0).b_zero_point(0), - test_func, isa_check) - .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "qmin", + tester.clone() + .m(mr).n(nr).k(k_block).qmin(128) + , test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "qmax", + tester.clone() + .m(mr).n(nr).k(k_block).qmax(128) + , test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "strided_cm", + tester.clone() + .m(mr).n(nr).k(k_block) + .cm_stride(xnnpack::NextPrime(nr + 1)) + , test_func, isa_check)); + gemm_tests.push_back(GemmTestParams( + "no_a_zero_point", + tester.clone() + .m(mr).n(nr).a_zero_point(0) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "no_b_zero_point", + tester.clone() + .m(mr).n(nr).b_zero_point(0) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); + gemm_tests.push_back(GemmTestParams( + "b_zero_point", + tester.clone() + .m(mr).n(nr).k(k_block) + , test_func, isa_check) + .loop_bzp(0, 255)); + gemm_tests.push_back(GemmTestParams( + "no_zero_point", + tester.clone() + .m(mr).n(nr) + .a_zero_point(0) + .b_zero_point(0) + , test_func, isa_check) + .loop_k(1, k_block * 3, k_block + 1)); return gemm_tests; } } // namespace -#if XNN_ARCH_ARM64 + +#if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_1X8__ASM_AARCH32_NEON_MLAL_LANE_CORTEX_A7, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_1X8__ASM_AARCH32_NEON_MLAL_LANE_CORTEX_A7_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__ASM_AARCH32_NEON_MLAL_LANE_CORTEX_A7, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + 
/*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a7, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__ASM_AARCH32_NEON_MLAL_LANE_CORTEX_A7_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__ASM_AARCH32_NEON_MLAL_LANE_CORTEX_A53, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__ASM_AARCH32_NEON_MLAL_LANE_CORTEX_A53_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__ASM_AARCH32_NEON_MLAL_LANE_LD64, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__ASM_AARCH32_NEON_MLAL_LANE_LD64_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), 
+ [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY + + +#if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__ASM_AARCH64_NEON_MLAL_LANE_CORTEX_A53, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__ASM_AARCH64_NEON_MLAL_LANE_CORTEX_A53_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU16_4X16__ASM_AARCH64_NEON_MLAL_LANE_CORTEX_A53_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu16_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, + xnn_init_qu8_conv_minmax_rndnu16_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu16); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__ASM_AARCH64_NEON_MLAL_LANE_CORTEX_A75, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a75, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__ASM_AARCH64_NEON_MLAL_LANE_CORTEX_A75_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a75_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__ASM_AARCH64_NEON_MLAL_LANE_LD64, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + 
/*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__ASM_AARCH64_NEON_MLAL_LANE_LD64_PRFM, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64_prfm, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY + + +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_1X8__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x8__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_1X16__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x16__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU16_1X16__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu16_ukernel_1x16__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu16_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu16); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_2X8__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x8__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + 
INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_2X16__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x16__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_3X8__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x8__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_3X16__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x16__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X8__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x8__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X16__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x16__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_6X8__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_6x8__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + 
QU8_GEMM_MINMAX_RNDNU_6X16__NEON_MLAL_LANE, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/8, + /*adj_k_block=*/8, + /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_6x16__neon_mlal_lane, + xnn_init_qu8_conv_minmax_rndnu_neon_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + }, + []() { + TEST_REQUIRES_ARM_NEON; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); +#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 + + INSTANTIATE_TEST_SUITE_P( - QU8_GEMM_MINMAX_RNDNU_1X16__NEON_MLAL_LANE, GemmTest, + QU8_GEMM_MINMAX_RNDNU_1X2__SCALAR, GemmTest, testing::ValuesIn(CreateTests1( - /*k_block=*/8, - /*adj_k_block=*/8, - /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/false, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { - tester.Test(xnn_qu8_gemm_minmax_rndnu16_ukernel_1x16__neon_mlal_lane, - xnn_init_qu8_conv_minmax_rndnu16_scalar_params, - xnn_pack_qu8_gemm_goi_w, xnn_qu8_requantize_rndnu16); - }, - []() { TEST_REQUIRES_ARM_NEON; })), + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x2__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_1X4__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_1x4__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_2X2__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x2__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_2X4__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_2x4__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_3X2__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x2__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const 
testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_3X4__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_3x4__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X2__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x2__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + +INSTANTIATE_TEST_SUITE_P( + QU8_GEMM_MINMAX_RNDNU_4X4__SCALAR, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/1, + /*adj_k_block=*/1, + /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_qu8_gemm_minmax_rndnu_ukernel_4x4__scalar, + xnn_init_qu8_conv_minmax_rndnu_scalar_params, + xnn_pack_qu8_gemm_goi_w, + xnn_qu8_requantize_rndnu); + })), [](const testing::TestParamInfo& info) { return info.param.test_name; }); -#endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 diff --git a/test/qu8-igemm-minmax-fp32-2.cc b/test/qu8-igemm-minmax-fp32-2.cc index 668ce4440115..0e2a600ef68a 100644 --- a/test/qu8-igemm-minmax-fp32-2.cc +++ b/test/qu8-igemm-minmax-fp32-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - 
, test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x2c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/2, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x2c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -382,6 +363,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -402,6 +384,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -422,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -445,6 +429,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -465,6 +450,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -485,6 +471,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -505,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -525,6 +513,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -545,6 +534,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -565,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, 
/*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -585,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -605,6 +597,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -625,6 +618,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -645,6 +639,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -665,6 +660,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -685,6 +681,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -705,6 +702,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -725,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -745,6 +744,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -765,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -785,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -805,6 +807,7 @@ 
std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -825,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -845,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -865,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -885,6 +891,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -905,6 +912,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -925,6 +933,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -945,6 +954,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -965,6 +975,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -985,6 +996,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1005,6 +1017,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1025,6 +1038,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__avx_ld128, 
xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1045,6 +1059,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1065,6 +1080,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1085,6 +1101,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1105,6 +1122,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1128,6 +1146,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x8c8__avx256skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1151,6 +1170,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_5x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1171,6 +1191,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_8x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1194,6 +1215,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1211,6 +1233,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1228,6 +1251,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1245,6 +1269,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1262,6 +1287,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1279,6 +1305,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1296,6 +1323,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1313,6 +1341,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1330,6 +1359,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1347,6 +1377,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1364,6 +1395,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1381,6 +1413,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1401,6 +1434,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1418,6 +1452,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1437,6 +1472,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1455,6 +1491,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ 
-1473,6 +1510,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1491,6 +1529,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1509,6 +1548,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1527,6 +1567,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1545,6 +1586,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1563,6 +1605,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1581,6 +1624,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1599,6 +1643,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1617,6 +1662,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1635,6 +1681,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, diff --git a/test/qu8-igemm-minmax-fp32.cc b/test/qu8-igemm-minmax-fp32.cc index 263dea0dfdc9..89dcc98c9b35 100644 --- a/test/qu8-igemm-minmax-fp32.cc +++ b/test/qu8-igemm-minmax-fp32.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); 
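
For context, each `CreateTests1` helper in these generated test files now takes a `bool unsigned_inputs` argument and forwards it to the tester through the fluent setter shown in the next hunk (`.mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs)`). A minimal sketch of that setter pattern, assuming a plain stored flag (the member name and default below are illustrative, not the real GemmMicrokernelTester implementation):

class GemmMicrokernelTesterSketch {
 public:
  // Fluent setter: returns *this so it can be chained after .mr()/.nr()/.kr()/.sr().
  GemmMicrokernelTesterSketch& unsigned_inputs(bool unsigned_inputs) {
    unsigned_inputs_ = unsigned_inputs;
    return *this;
  }
  bool unsigned_inputs() const { return unsigned_inputs_; }

 private:
  // Default matches the /*unsigned_inputs=*/false passed at every qu8 call site.
  bool unsigned_inputs_ = false;
};

Every qu8 call site in this patch passes /*unsigned_inputs=*/false, so these kernels keep their existing test behavior; the knob presumably matters only for kernel families that do consume unsigned inputs.
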
const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/1, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x1c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/4, /*mr=*/2, /*nr=*/1, /*kr=*/4, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x1c4__armsimd32, xnn_init_qu8_conv_minmax_fp32_armsimd32_params, @@ -382,6 +363,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16__neonv8_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neonv8_params, @@ -402,6 +384,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neon_params, @@ -422,6 +405,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16__neonv8_mlal_lane, xnn_init_qu8_conv_minmax_fp32_neonv8_params, @@ -445,6 +429,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -465,6 +450,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__sse41_ld64, 
xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -485,6 +471,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -505,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -525,6 +513,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -545,6 +534,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -565,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -585,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -605,6 +597,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -625,6 +618,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -645,6 +639,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -665,6 +660,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -685,6 +681,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -705,6 +702,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -725,6 +723,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -745,6 +744,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -765,6 +765,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__sse2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -785,6 +786,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__sse41_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -805,6 +807,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -825,6 +828,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__sse41_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -845,6 +849,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__sse2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -865,6 +870,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -885,6 +891,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -905,6 +912,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -925,6 +933,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -945,6 +954,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, 
[](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -965,6 +975,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -985,6 +996,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1005,6 +1017,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1025,6 +1038,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1045,6 +1059,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1065,6 +1080,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c8__avx_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1085,6 +1101,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c8__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1105,6 +1122,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__avx_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1125,6 +1143,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1145,6 +1164,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x8c8__avx2, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1168,6 +1188,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1188,6 +1209,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, 
/*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1208,6 +1230,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1228,6 +1251,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/5, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_5x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1248,6 +1272,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/7, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1268,6 +1293,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/8, /*nr=*/16, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_8x16c8__avx512skx_prfm, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1291,6 +1317,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1308,6 +1335,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1325,6 +1353,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1342,6 +1371,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1359,6 +1389,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1376,6 +1407,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1393,6 +1425,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1410,6 +1443,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1427,6 +1461,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1444,6 +1479,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/2, /*sr=*/4, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c2s4__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1461,6 +1497,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld64, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1478,6 +1515,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/4, /*kr=*/8, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4c8__wasmsimd_dot16x2_ld128, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1498,6 +1536,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1515,6 +1554,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1532,6 +1572,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1549,6 +1590,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1566,6 +1608,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x2__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1583,6 +1626,7 @@ std::vector CreateTests1( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4__wasm_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1602,6 +1646,7 @@ INSTANTIATE_TEST_SUITE_P( 
/*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1620,6 +1665,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1638,6 +1684,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1656,6 +1703,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1674,6 +1722,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x2__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1692,6 +1741,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x2__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1710,6 +1760,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4__scalar_fmagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1728,6 +1779,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/2, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x4__scalar_lrintf, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1746,6 +1798,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1764,6 +1817,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/3, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x4__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1782,6 +1836,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/2, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x2__scalar_imagic, xnn_init_qu8_conv_minmax_fp32_scalar_params, @@ -1800,6 +1855,7 @@ INSTANTIATE_TEST_SUITE_P( /*adj_k_block=*/1, /*mr=*/4, /*nr=*/4, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x4__scalar_imagic, 
xnn_init_qu8_conv_minmax_fp32_scalar_params, diff --git a/test/qu8-igemm-minmax-rndnu-2.cc b/test/qu8-igemm-minmax-rndnu-2.cc index 15c0c5d6692f..ad98db9e62ae 100644 --- a/test/qu8-igemm-minmax-rndnu-2.cc +++ b/test/qu8-igemm-minmax-rndnu-2.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -359,6 +339,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_1x8__asm_aarch32_neon_mlal_lane_cortex_a7_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -379,6 +360,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -399,6 +381,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -419,6 +402,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& 
tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_ld64_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -442,6 +426,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -462,6 +447,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu16_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53_prfm, xnn_init_qu8_conv_minmax_rndnu16_scalar_params, @@ -482,6 +468,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a75, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -502,6 +489,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a75_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -522,6 +510,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -542,6 +531,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_ld64_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -565,6 +555,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_2x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -585,6 +576,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_3x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -605,6 +597,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -625,6 +618,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_6x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -645,6 +639,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/6, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { 
tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_6x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, diff --git a/test/qu8-igemm-minmax-rndnu.cc b/test/qu8-igemm-minmax-rndnu.cc index 674a474bd7af..0b69ee645a07 100644 --- a/test/qu8-igemm-minmax-rndnu.cc +++ b/test/qu8-igemm-minmax-rndnu.cc @@ -35,6 +35,7 @@ std::vector CreateTests1( size_t k_block, size_t adj_k_block, size_t mr, size_t nr, size_t kr, size_t sr, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -43,7 +44,7 @@ std::vector CreateTests1( std::string nrs = std::to_string(nr); const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -53,12 +54,6 @@ std::vector CreateTests1( tester.clone() .m(mr).n(nr).k(k_block) , test_func, isa_check)); - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_strided_a", @@ -166,14 +161,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_strided_a", @@ -199,14 +186,6 @@ std::vector CreateTests1( , test_func, isa_check) .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); if (!is_igemm) { gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_strided_a", @@ -339,6 +318,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x8__asm_aarch32_neon_mlal_lane_cortex_a53_prfm, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -362,6 +342,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -382,6 +363,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu16_ukernel_4x16__asm_aarch64_neon_mlal_lane_cortex_a53, xnn_init_qu8_conv_minmax_rndnu16_scalar_params, @@ -405,6 +387,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_1x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -425,6 +408,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + 
/*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_1x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -445,6 +429,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/1, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu16_ukernel_1x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu16_scalar_params, @@ -465,6 +450,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/2, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_2x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -485,6 +471,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/3, /*nr=*/16, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_3x16__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, @@ -505,6 +492,7 @@ std::vector CreateTests1( /*adj_k_block=*/8, /*mr=*/4, /*nr=*/8, /*kr=*/1, /*sr=*/1, /*is_igemm=*/true, + /*unsigned_inputs=*/false, [](GemmMicrokernelTester& tester) { tester.Test(xnn_qu8_igemm_minmax_rndnu_ukernel_4x8__neon_mlal_lane, xnn_init_qu8_conv_minmax_rndnu_neon_params, diff --git a/test/raddextexp-microkernel-tester.h b/test/raddextexp-microkernel-tester.h deleted file mode 100644 index e7f875e548e0..000000000000 --- a/test/raddextexp-microkernel-tester.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include "xnnpack.h" -#include "xnnpack/microfnptr.h" -#include "xnnpack/buffer.h" -#include "replicable_random_device.h" - -class RAddExtExpMicrokernelTester { - public: - RAddExtExpMicrokernelTester& elements(size_t elements) { - assert(elements != 0); - this->elements_ = elements; - return *this; - } - - size_t elements() const { - return this->elements_; - } - - RAddExtExpMicrokernelTester& iterations(size_t iterations) { - this->iterations_ = iterations; - return *this; - } - - size_t iterations() const { - return this->iterations_; - } - - void Test(xnn_f32_raddextexp_ukernel_fn raddextexp) const { - xnnpack::ReplicableRandomDevice rng; - // Choose such range that expf(x[i]) overflows, but double-precision exp doesn't overflow. - auto f32rng = [&rng]() { - return std::uniform_real_distribution(90.0f, 100.0f)(rng); - }; - - xnnpack::Buffer x(elements() + XNN_EXTRA_BYTES / sizeof(float)); - for (size_t iteration = 0; iteration < iterations(); iteration++) { - std::generate(x.begin(), x.end(), std::ref(f32rng)); - - // Compute reference results. - double sum_ref = 0.0f; - for (size_t i = 0; i < elements(); i++) { - sum_ref += exp(double(x[i])); - } - - // Call optimized micro-kernel. - float sum[2]; - raddextexp(elements() * sizeof(float), x.data(), sum); - - // Verify results. 
- ASSERT_NEAR(sum_ref, exp2(double(sum[1])) * double(sum[0]), std::abs(sum_ref) * 1.0e-6) - << "elements = " << elements() << ", y:value = " << sum[0] << ", y:exponent = " << sum[1]; - } - } - - private: - size_t elements_{1}; - size_t iterations_{15}; -}; diff --git a/test/reshape-helpers.cc b/test/reshape-helpers.cc index 1c3ad0c02929..7857106f4025 100644 --- a/test/reshape-helpers.cc +++ b/test/reshape-helpers.cc @@ -21,6 +21,7 @@ #include "xnnpack/subgraph.h" #include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "runtime-flags.h" xnn_runtime_t SetupUnary(const std::vector &dims) { if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) { @@ -63,7 +64,7 @@ xnn_runtime_t SetupUnary(const std::vector &dims) { } xnn_runtime_t runtime = nullptr; - if (xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, + if (xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime) != xnn_status_success) { return nullptr; } @@ -108,7 +109,7 @@ xnn_runtime_t SetupBinary(const std::vector &input0_dims, uint32_t output_id = XNN_INVALID_NODE_ID; if (xnn_define_tensor_value(subgraph, xnn_datatype_fp32, 0, input1_dims.data(), nullptr, 2, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id) != xnn_status_success) { return nullptr; } @@ -124,7 +125,7 @@ xnn_runtime_t SetupBinary(const std::vector &input0_dims, } xnn_runtime_t runtime = nullptr; - if (xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, + if (xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime) != xnn_status_success) { return nullptr; } diff --git a/test/rope.cc b/test/rope.cc index cf86afc4e5d5..3ddac875100b 100644 --- a/test/rope.cc +++ b/test/rope.cc @@ -20,6 +20,7 @@ #include "xnnpack/subgraph.h" #include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class RoPETestBase : public ::testing::Test { protected: @@ -207,7 +208,7 @@ TEST_F(RoPETestF16, matches_operator_api) xnn_define_rope(subgraph, max_tokens, input_id, weights_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -285,7 +286,7 @@ TEST_F(RoPETestF32, matches_operator_api) xnn_define_rope(subgraph, max_tokens, input_id, weights_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -337,7 +338,7 @@ TEST_F(RoPETestF32, reshape_output) xnn_define_rope(subgraph, max_tokens, input_id, weights_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/runtime-flags.cc b/test/runtime-flags.cc new file mode 100644 
index 000000000000..cf2c91ea453a --- /dev/null +++ b/test/runtime-flags.cc @@ -0,0 +1,30 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "runtime-flags.h" + +#include +#if GTEST_HAS_ABSL +#include +#endif + +// We only define this if GTest has a separate Abseil install, so we +// can use Abseil's built-in command-line-flag processing. +#if GTEST_HAS_ABSL +ABSL_FLAG(uint32_t, xnn_runtime_flags, 0, + "Value to pass to xnn_create_runtime for flags"); +#endif + +extern "C" { + +uint32_t xnn_test_runtime_flags() { +#if GTEST_HAS_ABSL + return absl::GetFlag(FLAGS_xnn_runtime_flags); +#else + return 0; +#endif +} + +} diff --git a/test/runtime-flags.h b/test/runtime-flags.h new file mode 100644 index 000000000000..7a75b52cfe60 --- /dev/null +++ b/test/runtime-flags.h @@ -0,0 +1,11 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +extern "C" uint32_t xnn_test_runtime_flags(); + diff --git a/test/runtime-tester.h b/test/runtime-tester.h index ee4854a64c68..c71dc9b97104 100644 --- a/test/runtime-tester.h +++ b/test/runtime-tester.h @@ -11,12 +11,14 @@ #include #include #include +#include #include #include #include "xnnpack.h" #include "xnnpack/subgraph.h" #include "subgraph-tester.h" +#include "runtime-flags.h" namespace xnnpack { @@ -35,7 +37,7 @@ class RuntimeTester : public SubgraphTester { template xnnpack::Buffer RunWithoutFusion() { - Run(XNN_FLAG_NO_OPERATOR_FUSION); + Run(XNN_FLAG_NO_OPERATOR_FUSION | xnn_test_runtime_flags()); xnnpack::Buffer& tensor = this->external_tensors_.at(this->output_id_); xnnpack::Buffer output = xnnpack::Buffer(tensor.size() / sizeof(float)); memcpy(output.data(), tensor.data(), tensor.size()); @@ -51,7 +53,7 @@ class RuntimeTester : public SubgraphTester { return output; } - void CreateRuntime(uint32_t flags = 0) { + void CreateRuntime(uint32_t flags) { xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(this->subgraph_.get(), nullptr, nullptr, flags, &runtime)); ASSERT_NE(nullptr, runtime); @@ -121,7 +123,7 @@ class RuntimeTester : public SubgraphTester { } private: - void Run(uint32_t flags = 0) { + void Run(uint32_t flags = xnn_test_runtime_flags()) { CreateRuntime(flags); SetupRuntime(); diff --git a/test/s32-f32-vcvt.cc b/test/s32-f32-vcvt.cc deleted file mode 100644 index 8f08f5f11c06..000000000000 --- a/test/s32-f32-vcvt.cc +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
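
The new test/runtime-flags helper above gives every gtest binary an optional --xnn_runtime_flags command-line flag (when gtest is built against Abseil's flag library) and falls back to 0 otherwise; the rest of this patch threads its value into each xnn_create_runtime_v3 call. A minimal usage sketch, assuming a subgraph already defined by the test; the helper name here is hypothetical and only mirrors the call-site change applied throughout the patch:

#include <memory>

#include <gtest/gtest.h>

#include "xnnpack.h"
#include "runtime-flags.h"

// Hypothetical helper for illustration; `subgraph` is assumed to be built by the
// caller, exactly as the subgraph tests changed in this patch do.
static void CreateRuntimeWithTestFlags(xnn_subgraph_t subgraph) {
  xnn_runtime_t runtime = nullptr;
  // xnn_test_runtime_flags() returns the --xnn_runtime_flags value under Abseil,
  // and 0 otherwise, so the default matches the old /*flags=*/0 call sites.
  ASSERT_EQ(xnn_status_success,
            xnn_create_runtime_v3(subgraph, nullptr, nullptr,
                                  xnn_test_runtime_flags(), &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(
      runtime, xnn_delete_runtime);
  // ... set up and invoke the runtime as the individual tests already do.
}

With an Abseil-enabled gtest, a test binary (e.g. the softmax test changed below) can then be invoked with --xnn_runtime_flags=N, and the numeric value is passed verbatim as the flags argument of xnn_create_runtime_v3; without Abseil the helper returns 0 and the tests behave exactly as before.
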
- - -#include "xnnpack/microparams-init.h" -#include "xnnpack/vcvt.h" -#include "vunary-microkernel-tester.h" - -#define XNN_QUANTIZED(T) xnnpack::quantized -#define XNN_CVT_UKERNEL_WITH_PARAMS(arch_flags, ukernel, batch_tile, vector_tile, \ - datatype_in, datatype_out, params_type, init_params) \ - TEST(ukernel, batch_eq) { TestBatchEq(arch_flags, batch_tile, ukernel, init_params); } \ - TEST(ukernel, batch_div) { TestBatchDiv(arch_flags, batch_tile, ukernel, init_params); } \ - TEST(ukernel, batch_lt) { TestBatchLT(arch_flags, batch_tile, ukernel, init_params); } \ - TEST(ukernel, batch_gt) { TestBatchGT(arch_flags, batch_tile, ukernel, init_params); } \ - TEST(ukernel, input_zero_point) { TestInputZeroPoint(arch_flags, batch_tile, ukernel, init_params); } -#include "s32-f32-vcvt/s32-f32-vcvt.h" -#undef XNN_CVT_UKERNEL_WITH_PARAMS -#undef XNN_QUANTIZED diff --git a/test/scaled-dot-product-attention.cc b/test/scaled-dot-product-attention.cc index 85fcd273c95d..9b23faf9bd8d 100644 --- a/test/scaled-dot-product-attention.cc +++ b/test/scaled-dot-product-attention.cc @@ -24,6 +24,7 @@ #include "xnnpack/node-type.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class ScaledDotProductAttentionTestBase : public ::testing::Test { @@ -437,7 +438,7 @@ TEST_F(ScaledDotProductAttentionTestF16, matches_operator_api) { subgraph, cap_type, &cap_params, query_id, key_id, value_id, scale_id, mask_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -554,7 +555,7 @@ TEST_F(ScaledDotProductAttentionTestF32, matches_operator_api) { subgraph, cap_type, &cap_params, query_id, key_id, value_id, scale_id, mask_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -721,7 +722,7 @@ TEST_F(ScaledDotProductAttentionTestF32, matches_operator_api_dynamic_shape_no_r subgraph, cap_type, &cap_params, query_id, key_id, value_id, scale_id, mask_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -922,7 +923,7 @@ TEST_F(ScaledDotProductAttentionTestF32, matches_operator_api_dynamic_shape_requ subgraph, cap_type, &cap_params, query_id, key_id, value_id, scale_id, mask_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git 
a/test/softmax.cc b/test/softmax.cc index ae4f40b00f23..fe9c449cc658 100644 --- a/test/softmax.cc +++ b/test/softmax.cc @@ -18,6 +18,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" using SoftmaxTestF16 = UnaryTest; @@ -138,7 +139,7 @@ TEST_F(SoftmaxTestF16, matches_operator_api) xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, xnn_define_softmax(subgraph, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -197,7 +198,7 @@ TEST_F(SoftmaxTestF32, matches_operator_api) xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, xnn_define_softmax(subgraph, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -240,7 +241,7 @@ TEST_F(SoftmaxTestF32, reshape_output) ASSERT_EQ(xnn_status_success, xnn_define_softmax(subgraph, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/space-to-depth-2d.cc b/test/space-to-depth-2d.cc index 185def542819..4bda5ec61345 100644 --- a/test/space-to-depth-2d.cc +++ b/test/space-to-depth-2d.cc @@ -19,6 +19,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" using SpaceToDepth2DTestQS8 = UnaryTest; using SpaceToDepth2DTestQU8 = UnaryTest; @@ -268,7 +269,7 @@ TEST_F(SpaceToDepth2DTestQS8, matches_operator_api) xnn_define_space_to_depth_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -345,7 +346,7 @@ TEST_F(SpaceToDepth2DTestQU8, matches_operator_api) xnn_define_space_to_depth_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -416,7 +417,7 @@ TEST_F(SpaceToDepth2DTestF16, matches_operator_api) xnn_define_space_to_depth_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, 
/*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -487,7 +488,7 @@ TEST_F(SpaceToDepth2DTestF32, matches_operator_api) xnn_define_space_to_depth_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -531,7 +532,7 @@ TEST_F(SpaceToDepth2DTestF32, reshape_output) xnn_define_space_to_depth_2d(subgraph, block_size, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/static-constant-pad.cc b/test/static-constant-pad.cc index 3f92e5000943..5a39bd1c4dbd 100644 --- a/test/static-constant-pad.cc +++ b/test/static-constant-pad.cc @@ -21,6 +21,7 @@ #include "xnnpack/requantization.h" #include "xnnpack/subgraph.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" using StaticConstantPadTestInt8 = UnaryTest; using StaticConstantPadTestUint8 = UnaryTest; @@ -294,7 +295,7 @@ TEST_F(StaticConstantPadTestInt8, matches_operator_api) subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -369,7 +370,7 @@ TEST_F(StaticConstantPadTestUint8, matches_operator_api) subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -442,7 +443,7 @@ TEST_F(StaticConstantPadTestF16, matches_operator_api) subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -515,7 +516,7 @@ TEST_F(StaticConstantPadTestF32, matches_operator_api) subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - 
ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -569,7 +570,7 @@ TEST_F(StaticConstantPadTestF32, reshape_output) subgraph, pre_paddings.data(), post_paddings.data(), padding_value, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/static-expand-dims.cc b/test/static-expand-dims.cc index 14d68bc8bf63..c5e102b715c0 100644 --- a/test/static-expand-dims.cc +++ b/test/static-expand-dims.cc @@ -22,6 +22,7 @@ #include "xnnpack/subgraph.h" #include "replicable_random_device.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" template auto_runtime(runtime, xnn_delete_runtime); @@ -247,7 +248,7 @@ TEST_F(StaticExpandDimsTestF16, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_static_expand_dims(subgraph, new_axes.size(), new_axes.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/static-reduce.cc b/test/static-reduce.cc index 44664d9f4eee..a871cac743c2 100644 --- a/test/static-reduce.cc +++ b/test/static-reduce.cc @@ -28,6 +28,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" struct Param { using TupleT = std::tuple; @@ -421,7 +422,7 @@ TEST_P(ReduceTest, matches_operator_api) { xnn_runtime_t runtime = nullptr; ASSERT_EQ( xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime( runtime, xnn_delete_runtime); @@ -493,7 +494,7 @@ TEST_P(ReduceTest, reshape) { xnn_runtime_t runtime = nullptr; xnn_status status = - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } diff --git a/test/static-reshape.cc b/test/static-reshape.cc index 07f38face5e8..948d59faaf0c 100644 --- a/test/static-reshape.cc +++ b/test/static-reshape.cc @@ -22,6 +22,7 @@ #include "xnnpack/subgraph.h" #include "replicable_random_device.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" template auto_runtime(runtime, xnn_delete_runtime); @@ -321,7 +322,7 @@ TEST_F(StaticReshapeTestUint8, matches_operator_api) ASSERT_EQ(xnn_status_success, xnn_define_static_reshape(subgraph, new_dims_hint.size(), new_dims_hint.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, 
nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -378,7 +379,7 @@ TEST_F(StaticReshapeTestF16, matches_operator_api) xnn_status_success, xnn_define_static_reshape(subgraph, output_dims.size(), output_dims.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -435,7 +436,7 @@ TEST_F(StaticReshapeTestF32, matches_operator_api) xnn_status_success, xnn_define_static_reshape(subgraph, new_dims_hint.size(), new_dims_hint.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -482,7 +483,7 @@ TEST_F(StaticReshapeTestF32, reshape_output) { /*flags=*/0)); ASSERT_EQ( xnn_status_success, - xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime( runtime, xnn_delete_runtime); diff --git a/test/static-resize-bilinear-2d.cc b/test/static-resize-bilinear-2d.cc index d2f6f850e991..0c516a09927d 100644 --- a/test/static-resize-bilinear-2d.cc +++ b/test/static-resize-bilinear-2d.cc @@ -21,6 +21,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class StaticResizeBilinear2DTestBase : public ::testing::Test { protected: @@ -291,7 +292,7 @@ TEST_F(StaticResizeBilinear2DTestQS8, matches_operator_api) xnn_define_static_resize_bilinear_2d(subgraph, output_height, output_width, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -358,7 +359,7 @@ TEST_F(StaticResizeBilinear2DTestQU8, matches_operator_api) xnn_define_static_resize_bilinear_2d(subgraph, output_height, output_width, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -421,7 +422,7 @@ TEST_F(StaticResizeBilinear2DTestF16, matches_operator_api) xnn_define_static_resize_bilinear_2d(subgraph, output_height, output_width, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = 
nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -484,7 +485,7 @@ TEST_F(StaticResizeBilinear2DTestF32, matches_operator_api) xnn_define_static_resize_bilinear_2d(subgraph, output_height, output_width, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -523,7 +524,7 @@ TEST_F(StaticResizeBilinear2DTestF32, reshape_output) xnn_define_static_resize_bilinear_2d(subgraph, output_height, output_width, input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/static-slice.cc b/test/static-slice.cc index 43c71b943431..f3e4945b0a0b 100644 --- a/test/static-slice.cc +++ b/test/static-slice.cc @@ -22,6 +22,7 @@ #include "xnnpack/subgraph.h" #include "replicable_random_device.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" template class StaticSliceTest : public UnaryTest { public: @@ -286,7 +287,7 @@ TEST_F(StaticSliceTestQS8, matches_operator_api) xnn_status_success, xnn_define_static_slice(subgraph, dims.size(), offsets.data(), inferrable_sizes.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -345,7 +346,7 @@ TEST_F(StaticSliceTestQU8, matches_operator_api) xnn_status_success, xnn_define_static_slice(subgraph, dims.size(), offsets.data(), inferrable_sizes.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -400,7 +401,7 @@ TEST_F(StaticSliceTestF16, matches_operator_api) xnn_status_success, xnn_define_static_slice(subgraph, dims.size(), offsets.data(), sizes.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); 
std::array external = { @@ -455,7 +456,7 @@ TEST_F(StaticSliceTestF32, matches_operator_api) xnn_status_success, xnn_define_static_slice(subgraph, dims.size(), offsets.data(), inferrable_sizes.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -493,7 +494,7 @@ TEST_F(StaticSliceTestF32, reshape_output) xnn_status_success, xnn_define_static_slice(subgraph, dims.size(), offsets.data(), inferrable_sizes.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/static-transpose.cc b/test/static-transpose.cc index e93da0052aff..63ffd28304a0 100644 --- a/test/static-transpose.cc +++ b/test/static-transpose.cc @@ -20,6 +20,7 @@ #include "xnnpack/operator.h" #include "xnnpack/subgraph.h" #include "subgraph-unary-tester.h" +#include "runtime-flags.h" using StaticTransposeTestQS8 = UnaryTest; using StaticTransposeTestQU8 = UnaryTest; @@ -273,7 +274,7 @@ TEST_F(StaticTransposeTestQS8, matches_operator_api) xnn_define_static_transpose(subgraph, perm.size(), perm.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -335,7 +336,7 @@ TEST_F(StaticTransposeTestQU8, matches_operator_api) xnn_define_static_transpose(subgraph, perm.size(), perm.data(), input_id, output_id, /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -395,7 +396,7 @@ TEST_F(StaticTransposeTestF16, matches_operator_api) ASSERT_EQ( xnn_status_success, xnn_define_static_transpose(subgraph, perm.size(), perm.data(), input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -454,7 +455,7 @@ TEST_F(StaticTransposeTestF32, matches_operator_api) ASSERT_EQ( xnn_status_success, xnn_define_static_transpose(subgraph, perm.size(), perm.data(), input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, 
nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { diff --git a/test/subgraph-fp16.cc b/test/subgraph-fp16.cc index 01beba54720e..230ae8fb80f3 100644 --- a/test/subgraph-fp16.cc +++ b/test/subgraph-fp16.cc @@ -17,11 +17,13 @@ #include #include "xnnpack.h" #include "xnnpack/allocation-type.h" +#include "xnnpack/buffer.h" #include "xnnpack/math.h" #include "xnnpack/node-type.h" #include "xnnpack/subgraph.h" #include "mock-allocator.h" #include "replicable_random_device.h" +#include "runtime-flags.h" #include "runtime-tester.h" #include "subgraph-tester.h" @@ -368,7 +370,7 @@ TEST(SUBGRAPH_FP16, external_inputs_allocation_type_remains_external) { xnn_runtime_t runtime = tester.Runtime(); xnn_status status = xnn_create_runtime_v3(tester.Subgraph(), nullptr, nullptr, - /*flags=*/0, &runtime); + xnn_test_runtime_flags(), &runtime); std::unique_ptr auto_runtime( runtime, xnn_delete_runtime); @@ -644,7 +646,9 @@ TEST(SUBGRAPH_FP16, fully_connected_qd8_f16_qc8w) { ASSERT_EQ(tester.NumNodes(), 4); xnn_runtime_t fp16_runtime_ptr = nullptr; - xnn_status status = xnn_create_runtime(tester.Subgraph(), &fp16_runtime_ptr); + xnn_status status = + xnn_create_runtime_v2(tester.Subgraph(), /*threadpool*/ nullptr, + xnn_test_runtime_flags(), &fp16_runtime_ptr); if (status == xnn_status_unsupported_hardware) { GTEST_SKIP(); } @@ -652,7 +656,9 @@ TEST(SUBGRAPH_FP16, fully_connected_qd8_f16_qc8w) { fp16_runtime_ptr, xnn_delete_runtime); ASSERT_EQ(xnn_status_success, status); xnn_runtime_t fp32_runtime_ptr = nullptr; - status = xnn_create_runtime(reference_tester.Subgraph(), &fp32_runtime_ptr); + status = + xnn_create_runtime_v2(reference_tester.Subgraph(), /*threadpool*/ nullptr, + xnn_test_runtime_flags(), &fp32_runtime_ptr); ASSERT_EQ(xnn_status_success, status); std::unique_ptr auto_fp32_runtime( fp32_runtime_ptr, xnn_delete_runtime); diff --git a/test/transpose-reshape.cc b/test/transpose-reshape.cc index 718e9fa25524..ebe2dcf71f90 100644 --- a/test/transpose-reshape.cc +++ b/test/transpose-reshape.cc @@ -14,6 +14,7 @@ #include "xnnpack.h" #include "xnnpack/node-type.h" #include "xnnpack/subgraph.h" +#include "runtime-flags.h" TEST(TransposeTestF32, Reshape) { @@ -51,7 +52,7 @@ TEST(TransposeTestF32, Reshape) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/unary-ops.cc b/test/unary-ops.cc index 76205713cab4..4b5554a26843 100644 --- a/test/unary-ops.cc +++ b/test/unary-ops.cc @@ -88,4 +88,5 @@ const UnaryOpInfo* GetUnaryOpInfo(xnn_unary_operator op) { case xnn_unary_invalid: return nullptr; } + return nullptr; } \ No newline at end of file diff --git a/test/unary-ops.h b/test/unary-ops.h index ff1cbf7b37c4..b7e00a00aa1c 100644 --- a/test/unary-ops.h +++ b/test/unary-ops.h @@ -6,8 +6,6 @@ #ifndef THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_ #define THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_ -#pragma once - #include #include #include @@ -20,11 +18,17 @@ #include "xnnpack.h" #include "xnnpack/buffer.h" +#include "xnnpack/common.h" +#include "xnnpack/datatype.h" #include "xnnpack/math.h" #include "xnnpack/reference-utils.h" static float TolExact(float) { return 0.0f; } -static 
float TolExact16(float y_ref) { return std::abs(y_ref) * 1.0e-3f; } +static float TolExact16(float y_ref) { + // The maximum of the relative tolerance and half the smallest positive + // normal. + return std::max(std::abs(y_ref) * 9.77e-04, 0.5 * 6.10e-5); +} static float TolRelative(float y_ref, float rel_tol) { // Note that `y_ref * rel_tol`, i.e. the expected absolute difference, @@ -214,10 +218,12 @@ struct GELU : public UnaryOpInfo { float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: - case xnn_datatype_fp16: - case xnn_datatype_bf16: return TolMixed(y_ref, 10 * std::numeric_limits::epsilon(), 5 * std::numeric_limits::epsilon()); + case xnn_datatype_fp16: + return TolMixed(y_ref, 10 * 9.77e-04, 5 * 9.77e-04); + case xnn_datatype_bf16: + return TolMixed(y_ref, 10 * 7.8125e-3, 5 * 7.8125e-3); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; @@ -599,7 +605,8 @@ void UnaryReferenceImpl( const xnn_quantization_params& output_quantization = {0, 1.0f}, const xnn_unary_params& params = xnn_unary_params()) { for (size_t i = 0; i < n; i++) { - float x_i = (x[i] - input_quantization.zero_point) * input_quantization.scale; + float x_i = + (x[i] - input_quantization.zero_point) * input_quantization.scale; float y_i = op_info.ReferenceImpl(x_i, params); y_i = y_i / output_quantization.scale + output_quantization.zero_point; y[i] = xnnpack::round_float_to_int(y_i); @@ -630,7 +637,8 @@ void UnaryReferenceImpl( const xnn_unary_params& params = xnn_unary_params()) { static_assert(!xnnpack::is_quantized::value, ""); for (size_t i = 0; i < n; i++) { - float x_i = (x[i] - input_quantization.zero_point) * input_quantization.scale; + float x_i = + (x[i] - input_quantization.zero_point) * input_quantization.scale; float y_i = op_info.ReferenceImpl(x_i, params); if (std::is_integral::value) { y[i] = xnnpack::round_float_to_int(y_i); diff --git a/test/unary.cc b/test/unary.cc index 3d67e3767298..7d5732e52aea 100644 --- a/test/unary.cc +++ b/test/unary.cc @@ -28,6 +28,7 @@ #include "xnnpack/subgraph.h" #include "replicable_random_device.h" #include "unary-ops.h" +#include "runtime-flags.h" struct Param { using UnaryT = std::tuple; @@ -152,7 +153,7 @@ TEST_P(UnaryTest, matches_operator_api) { ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, unary_operator, ¶ms, input_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -250,7 +251,7 @@ TEST(AbsTest, reshape) { ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/unpooling-2d.cc b/test/unpooling-2d.cc index 8b5be251cee1..e753ce15e8f0 100644 --- a/test/unpooling-2d.cc +++ b/test/unpooling-2d.cc @@ -19,6 +19,7 @@ #include "xnnpack/subgraph.h" #include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "runtime-flags.h" template class Unpooling2DTestBase : public ::testing::Test { protected: @@ -206,7 +207,7 @@ 
TEST_F(Unpooling2DTestX32, matches_operator_api) /*flags=*/0)); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -263,7 +264,7 @@ TEST_F(Unpooling2DTestX32, reshape_output) ASSERT_EQ(node->flags, 0); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); diff --git a/test/workspace.cc b/test/workspace.cc index bad27879a7e4..6c0040fe9020 100644 --- a/test/workspace.cc +++ b/test/workspace.cc @@ -5,7 +5,6 @@ #include #include -#include #include #include #include @@ -17,10 +16,12 @@ #include #include "xnnpack.h" #include "xnnpack/allocation-type.h" +#include "xnnpack/common.h" #include "xnnpack/math.h" #include "xnnpack/subgraph.h" #include "xnnpack/buffer.h" #include "replicable_random_device.h" +#include "runtime-flags.h" namespace { void DefineGraphWithoutInternalTensors(xnn_subgraph_t* subgraph, std::array dims) @@ -175,7 +176,7 @@ TEST(WORKSPACE, static_data_not_moved_does_not_segv) DefineGraphWithStaticData(&subgraph1, dims, &static_data); std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); const std::array external_values1 = { xnn_external_value{0, static_data.data()}, xnn_external_value{2, static_data.data()}, @@ -194,7 +195,7 @@ TEST(WORKSPACE, static_data_not_moved_does_not_segv) DefineGraph(&subgraph2, dims); std::unique_ptr auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); const std::array external_values2 = { xnn_external_value{0, static_data.data()}, xnn_external_value{2, static_data.data()}, @@ -250,7 +251,7 @@ TEST(WORKSPACE, workspace_no_growth) std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime1, 2, external_values.data())); std::unique_ptr auto_runtime1(runtime1, xnn_delete_runtime); @@ -265,7 +266,7 @@ TEST(WORKSPACE, workspace_no_growth) std::unique_ptr auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); 
ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime2, 2, external_values.data())); std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); @@ -319,7 +320,7 @@ TEST(WORKSPACE, workspace_grow) std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); // No workspace allocated yet, it should be only allocated on setup. ASSERT_EQ(workspace->size, 0); @@ -341,7 +342,7 @@ TEST(WORKSPACE, workspace_grow) std::unique_ptr auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime2, 2, external_values2.data())); std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); @@ -398,7 +399,7 @@ TEST(WORKSPACE, workspace_runtime_delete_head_runtime_first) std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime1, 2, external_values.data())); std::unique_ptr auto_runtime1(runtime1, xnn_delete_runtime); @@ -407,7 +408,7 @@ TEST(WORKSPACE, workspace_runtime_delete_head_runtime_first) std::unique_ptr auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime2, 2, external_values.data())); std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); @@ -450,7 +451,7 @@ TEST(WORKSPACE, workspace_runtime_delete_tail_runtime_first) std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime1, 2, external_values.data())); std::unique_ptr auto_runtime1(runtime1, xnn_delete_runtime); @@ -459,7 +460,7 @@ TEST(WORKSPACE, workspace_runtime_delete_tail_runtime_first) std::unique_ptr auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime2, 2, external_values.data())); std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); @@ -503,7 +504,7 @@ TEST(WORKSPACE, workspace_runtime_delete_middle_runtime_first) 
std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime1, 2, external_values.data())); std::unique_ptr auto_runtime1(runtime1, xnn_delete_runtime); @@ -512,7 +513,7 @@ TEST(WORKSPACE, workspace_runtime_delete_middle_runtime_first) std::unique_ptr auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime2, 2, external_values.data())); std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); @@ -521,7 +522,7 @@ TEST(WORKSPACE, workspace_runtime_delete_middle_runtime_first) std::unique_ptr auto_subgraph3(subgraph3, xnn_delete_subgraph); xnn_runtime_t runtime3 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph3, nullptr, workspace, nullptr, 0, &runtime3)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph3, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime3)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime3, 2, external_values.data())); std::unique_ptr auto_runtime3(runtime3, xnn_delete_runtime); @@ -576,7 +577,7 @@ TEST(WORKSPACE, zero_sized_workspace_for_graph_without_internal_tensors) std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph, nullptr, workspace, nullptr, 0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, 2, external_values.data())); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -609,7 +610,7 @@ TEST(WORKSPACE, persistent_tensors_allocated_at_start_of_workspace) const std::unique_ptr auto_subgraph(subgraph, xnn_delete_subgraph); xnn_runtime_t runtime = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph, nullptr, workspace, nullptr, 0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, 2, external_values.data())); const std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); @@ -656,7 +657,7 @@ TEST(WORKSPACE, persistent_tensors_updated_correct_when_workspace_grows) std::unique_ptr auto_subgraph1(subgraph1, xnn_delete_subgraph); xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime1, 2, external_values.data())); const std::unique_ptr auto_runtime(runtime1, xnn_delete_runtime); @@ -672,7 +673,7 @@ TEST(WORKSPACE, persistent_tensors_updated_correct_when_workspace_grows) DefineGraphWithPersistentTensors(&subgraph2, dims2); const std::unique_ptr 
auto_subgraph2(subgraph2, xnn_delete_subgraph); xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime2, 2, external_values2.data())); const std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); @@ -766,7 +767,7 @@ TEST(WORKSPACE, persistent_tensors_values_copied_when_workspace_grows) } xnn_runtime_t runtime1 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime1)); const std::unique_ptr auto_runtime(runtime1, xnn_delete_runtime); xnnpack::Buffer expected(2 * 2 * 2 * 3 + XNN_EXTRA_BYTES / sizeof(float), 3.14f); const std::array external_values = { @@ -776,7 +777,7 @@ TEST(WORKSPACE, persistent_tensors_values_copied_when_workspace_grows) // Create the same graph but with larger tensors, this will require a larger workspace. xnn_runtime_t runtime2 = nullptr; - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, xnn_test_runtime_flags(), &runtime2)); const std::unique_ptr auto_runtime2(runtime2, xnn_delete_runtime); const size_t old_workspace_size = workspace->size; @@ -863,7 +864,7 @@ TEST(WORKSPACE, internally_allocated_dynamic_quantization_parameters) xnn_runtime_t runtime = nullptr; ASSERT_EQ(xnn_status_success, xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, input_id, dq_quantized_id, /*flags=*/0)); ASSERT_EQ(xnn_status_success, xnn_define_fully_connected(subgraph, output_min, output_max, dq_quantized_id, kernel_id, bias_id, output_id, /*flags=*/0)); - ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime)); + ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)); ASSERT_NE(nullptr, runtime); std::unique_ptr auto_runtime(runtime, xnn_delete_runtime); std::array external = { @@ -873,9 +874,15 @@ TEST(WORKSPACE, internally_allocated_dynamic_quantization_parameters) size_t dq_tensors = 0; for (size_t i = 0; i < runtime->num_values; i++) { const xnn_value* value = &runtime->values[i]; - if (value->datatype == xnn_datatype_qdint8) { - ++dq_tensors; - ASSERT_NE(value->quantization.dynamic_params, nullptr); + switch (value->datatype) { + case xnn_datatype_qdint8: + ASSERT_NE(value->quantization.dynamic_params, nullptr); + XNN_FALLTHROUGH; + case xnn_datatype_qpint8: + ++dq_tensors; + break; + default: + break; } } ASSERT_EQ(dq_tensors, 1); diff --git a/tools/generate-gemm-test.py b/tools/generate-gemm-test.py index c2c00e6b3167..bd7ffd8794b0 100755 --- a/tools/generate-gemm-test.py +++ b/tools/generate-gemm-test.py @@ -103,6 +103,7 @@ def split_ukernel_name(name): $if DATATYPE in ('qp8'): size_t mr_packed, bool is_igemm, + bool unsigned_inputs, std::function test_func, std::function isa_check = nullptr) { std::string kbs = std::to_string(k_block); @@ -114,10 +115,10 @@ def split_ukernel_name(name): $if DATATYPE in ('qp8',): const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed); + 
.mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed).unsigned_inputs(unsigned_inputs); $else: const GemmMicrokernelTester tester = GemmMicrokernelTester() - .mr(mr).nr(nr).kr(kr).sr(sr); + .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs); std::vector gemm_tests; gemm_tests.reserve(42); @@ -132,28 +133,18 @@ def split_ukernel_name(name): .bl(32) , test_func, isa_check)); $if DATATYPE != "qp8": - gemm_tests.push_back(GemmTestParams( - "strided_cn", - tester.clone() - .m(mr).n(nr).k(k_block) - .cn_stride(xnnpack::NextPrime(nr + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_eq_" + kbs + "_strided_a", - tester.clone() - .m(mr).n(nr).k(k_block) - .a_stride(xnnpack::NextPrime(k_block + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check)); - } + if (!is_igemm) { + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kbs + "_strided_a", + tester.clone() + .m(mr).n(nr).k(k_block) + .a_stride(xnnpack::NextPrime(k_block + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check)); + } gemm_tests.push_back(GemmTestParams( "k_eq_" + kbs + "_subtile", tester.clone() @@ -195,18 +186,19 @@ def split_ukernel_name(name): $if KERNELTYPE in ['qb4w']: .bl(32) , test_func, isa_check)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_eq_" + kb2s + "_strided_a", - tester.clone() - .m(mr).n(nr).k(k_block * 2) - .a_stride(xnnpack::NextPrime(k_block * 2 + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check)); - } + $if DATATYPE != "qp8": + if (!is_igemm) { + gemm_tests.push_back(GemmTestParams( + "k_eq_" + kb2s + "_strided_a", + tester.clone() + .m(mr).n(nr).k(k_block * 2) + .a_stride(xnnpack::NextPrime(k_block * 2 + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check)); + } gemm_tests.push_back(GemmTestParams( "k_eq_" + kb2s + "_subtile", tester.clone() @@ -230,19 +222,20 @@ def split_ukernel_name(name): .bl(32) , test_func, isa_check) .loop_k(1, adj_k_block - 1)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_lt_" + akbs + "_strided_a", - tester.clone() - .m(mr).n(nr) - .a_stride(xnnpack::NextPrime(adj_k_block + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - .loop_k(1, adj_k_block - 1)); - } + $if DATATYPE != "qp8": + if (!is_igemm) { + gemm_tests.push_back(GemmTestParams( + "k_lt_" + akbs + "_strided_a", + tester.clone() + .m(mr).n(nr) + .a_stride(xnnpack::NextPrime(adj_k_block + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check) + .loop_k(1, adj_k_block - 1)); + } gemm_tests.push_back(GemmTestParams( "k_lt_" + akbs + "_subtile", tester.clone() @@ -266,19 +259,20 @@ def split_ukernel_name(name): .bl(32) , test_func, isa_check) .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); - if (is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_gt_" + akbs + "_strided_a", - tester.clone() - .m(mr).n(nr) - .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - 
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); - } + $if DATATYPE != "qp8": + if (is_igemm) { + gemm_tests.push_back(GemmTestParams( + "k_gt_" + akbs + "_strided_a", + tester.clone() + .m(mr).n(nr) + .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check) + .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)); + } gemm_tests.push_back(GemmTestParams( "k_gt_" + akbs + "_subtile", tester.clone() @@ -302,19 +296,20 @@ def split_ukernel_name(name): .bl(32) , test_func, isa_check) .loop_k(adj_k_block + k_block, k_block * 5, k_block)); - if (is_igemm) { - gemm_tests.push_back(GemmTestParams( - "k_div_" + kbs + "_strided_a", - tester.clone() - .m(mr).n(nr) - .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - .loop_k(adj_k_block + k_block, k_block * 3, k_block)); - } + $if DATATYPE != "qp8": + if (is_igemm) { + gemm_tests.push_back(GemmTestParams( + "k_div_" + kbs + "_strided_a", + tester.clone() + .m(mr).n(nr) + .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check) + .loop_k(adj_k_block + k_block, k_block * 3, k_block)); + } gemm_tests.push_back(GemmTestParams( "k_div_" + kbs + "_subtile", tester.clone() @@ -343,40 +338,23 @@ def split_ukernel_name(name): .loop_n(nr + 1, nr * 2 - 1) .loop_k(1, k_block * 3, k_block + 1)); $if DATATYPE != "qp8": - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - $if NR_SCALE != "": - .loop_n(nr + 1, nr * 2 - 1, 4) - $else: - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block + 1)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "n_gt_" + nrs + "_strided_a", - tester.clone() - .m(mr) - .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - $if NR_SCALE != "": - .loop_n(nr + 1, nr * 2 - 1, 4) - $else: - .loop_n(nr + 1, nr * 2 - 1) - .loop_k(1, k_block * 3, k_block)); - } + if (!is_igemm) { + gemm_tests.push_back(GemmTestParams( + "n_gt_" + nrs + "_strided_a", + tester.clone() + .m(mr) + .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check) + $if NR_SCALE != "": + .loop_n(nr + 1, nr * 2 - 1, 4) + $else: + .loop_n(nr + 1, nr * 2 - 1) + .loop_k(1, k_block * 3, k_block)); + } gemm_tests.push_back(GemmTestParams( "n_gt_" + nrs + "_subtile", tester.clone() @@ -404,32 +382,20 @@ def split_ukernel_name(name): .loop_n(nr * 2, nr * 3, nr) .loop_k(1, k_block * 3, k_block + 1)); $if DATATYPE != "qp8": - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_cn", - tester.clone() - .m(mr) - .cn_stride(xnnpack::NextPrime(nr + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block + 1)); - if (!is_igemm) { - gemm_tests.push_back(GemmTestParams( - "n_div_" + nrs + "_strided_a", 
- tester.clone() - .m(mr) - .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) - $if KERNELTYPE in ['qb4w', 'qc4w']: - .b_zero_point(8) - $if KERNELTYPE in ['qb4w']: - .bl(32) - , test_func, isa_check) - .loop_n(nr * 2, nr * 3, nr) - .loop_k(1, k_block * 3, k_block)); - } + if (!is_igemm) { + gemm_tests.push_back(GemmTestParams( + "n_div_" + nrs + "_strided_a", + tester.clone() + .m(mr) + .a_stride(xnnpack::NextPrime(k_block * 3 + 1)) + $if KERNELTYPE in ['qb4w', 'qc4w']: + .b_zero_point(8) + $if KERNELTYPE in ['qb4w']: + .bl(32) + , test_func, isa_check) + .loop_n(nr * 2, nr * 3, nr) + .loop_k(1, k_block * 3, k_block)); + } gemm_tests.push_back(GemmTestParams( "n_div_" + nrs + "_subtile", tester.clone() @@ -613,8 +579,9 @@ def split_ukernel_name(name): $if DATATYPE in ('qp8',): /*mr_packed=*/${MR_PACKED}, /*is_igemm=*/${"true" if UKERNEL_TYPE.startswith("IGEMM") else "false"}, + /*unsigned_inputs=*/${"true" if UNSIGNED_INPUTS else "false"}, [](GemmMicrokernelTester& tester) { - tester.Test(${",\\n ".join(TEST_ARGS)}); + tester.${TEST_FUN}(${",\\n ".join(TEST_ARGS)}); $if ISA_CHECK: }, []() { @@ -648,7 +615,7 @@ def split_ukernel_name(name): .iterations(1) $if KERNELTYPE in ['qb4w', 'qc4w']: .b_zero_point(8) - .Test(${", ".join(TEST_ARGS)}); + .${TEST_FUN}(${", ".join(TEST_ARGS)}); } } } @@ -673,7 +640,7 @@ def split_ukernel_name(name): $if NR > 1: .n(${NR}) .k(${KBLOCK}) - .Test( + .${TEST_FUN}( ${", ".join(TEST_ARGS)}, &${PROTOTYPE}); } @@ -691,6 +658,7 @@ def generate_test_cases( sr, mr_packed, k_block, + unsigned_inputs, vector_tile, init_fn, pack_fn, @@ -712,6 +680,8 @@ def generate_test_cases( mr_packed: Optional MR parameter for the left-hand packing function. k_block: Number of K values processed per one iteration of the main loop of the micro-kernel. + unsigned_inputs: whether the inputs should be converted to unsigned + integers. Some microkernels are more efficient with unsigned inputs. vector_tile: Indicates if vector tile for NR is specified in vectors rather than elements. init_fn: C name of the function to initialize microkernel parameters. 
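Note on the new `unsigned_inputs` generator parameter documented in the hunk above: it maps to the `.unsigned_inputs(...)` setter on `GemmMicrokernelTester`, and the docstring says some microkernels are more efficient with unsigned inputs. The usual arithmetic behind that is a sign-bit flip with a matching zero-point shift; the standalone check below illustrates it. This is an illustration of that general technique only, not code from this patch, and its connection to the specific kernels covered by the flag is an inference.

#include <cassert>
#include <cstdint>

// Flipping the sign bit turns an int8 quantized value into a uint8 one with the
// same meaning, provided the zero point is shifted by 128.
int main() {
  const int32_t zero_point = -5;  // arbitrary int8 zero point
  for (int v = -128; v <= 127; ++v) {
    const int8_t x = static_cast<int8_t>(v);
    const uint8_t u = static_cast<uint8_t>(x) ^ 0x80;  // same byte, sign bit flipped
    // Dequantized values agree once the zero point is shifted by 128.
    assert(static_cast<int32_t>(x) - zero_point ==
           static_cast<int32_t>(u) - (zero_point + 128));
  }
  return 0;
}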
@@ -779,9 +749,15 @@ def generate_test_cases( "f32": "float", }[datatype] nr_scale = {"rvv": " * xnn_init_hardware_config()->vlenb / sizeof(%s)" % ctype}[isa] + test_fun_name = "".join(ukernel.split("_")[1:4]).upper() + if test_fun_name in {"QP8F32QC8W"}: + test_fun_name = "_".join(["Test", test_fun_name]) + else: + test_fun_name = "Test" test_args = { "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""), "TEST_ARGS": test_args, + "TEST_FUN": test_fun_name, "UKERNEL_TYPE": ukernel_type.upper(), "DATATYPE": datatype, "KERNELTYPE": kerneltype, @@ -792,6 +768,7 @@ def generate_test_cases( "SR": sr, "MR_PACKED": mr_packed, "KBLOCK": k_block, + "UNSIGNED_INPUTS": unsigned_inputs, "NR_SCALE": nr_scale, "ADJKBLOCK": 2 * k_block if is_pipelined else k_block, "IS_PIPELINED": is_pipelined, @@ -907,6 +884,10 @@ def main(args): for ukernel_spec in spec_yaml: name = ukernel_spec["name"] k_block = int(ukernel_spec["k-block"]) + if "unsigned-inputs" in ukernel_spec: + unsigned_inputs = int(ukernel_spec["unsigned-inputs"]) + else: + unsigned_inputs = False init_fn = ukernel_spec.get("init") pack_fn = ukernel_spec.get("pack") packed_stride_fn = ukernel_spec.get("packed-stride") @@ -934,6 +915,7 @@ def main(args): sr, mr_packed, k_block, + unsigned_inputs, vector_tile, init_fn, pack_fn, diff --git a/tools/generate-raddextexp-test.py b/tools/generate-raddextexp-test.py deleted file mode 100755 index 404d1ba73342..000000000000 --- a/tools/generate-raddextexp-test.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python -# Copyright 2019 Google LLC -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import codecs -import math -import os -import re -import sys -import yaml - -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -import xngen -import xnncommon - - -parser = argparse.ArgumentParser( - description='RAddExtExp microkernel test generator') -parser.add_argument("-s", "--spec", metavar="FILE", required=True, - help="Specification (YAML) file") -parser.add_argument("-o", "--output", metavar="FILE", required=True, - help='Output (C++ source) file') -parser.set_defaults(defines=list()) - - -def split_ukernel_name(name): - match = re.fullmatch(r"xnn_(f16|f32)_raddextexp_ukernel__(.+)_u(\d+)(_acc(\d+))?", name) - if match is None: - raise ValueError("Unexpected microkernel name: " + name) - elements_tile = int(match.group(3)) - - arch, isa, assembly = xnncommon.parse_target_name(target_name=match.group(2)) - return elements_tile, arch, isa - - -RADDEXTEXP_TEST_TEMPLATE = """\ -TEST(${TEST_NAME}, elements_eq_${ELEMENTS_TILE}) { - $if ISA_CHECK: - ${ISA_CHECK}; - RAddExtExpMicrokernelTester() - .elements(${ELEMENTS_TILE}) - .Test(${TEST_FUNCTION}); -} - -$if ELEMENTS_TILE > 1: - TEST(${TEST_NAME}, elements_div_${ELEMENTS_TILE}) { - $if ISA_CHECK: - ${ISA_CHECK}; - for (size_t elements = ${ELEMENTS_TILE*2}; elements < ${ELEMENTS_TILE*10}; elements += ${ELEMENTS_TILE}) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(${TEST_FUNCTION}); - } - } - - TEST(${TEST_NAME}, elements_lt_${ELEMENTS_TILE}) { - $if ISA_CHECK: - ${ISA_CHECK}; - for (size_t elements = 1; elements < ${ELEMENTS_TILE}; elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(${TEST_FUNCTION}); - } - } - -TEST(${TEST_NAME}, elements_gt_${ELEMENTS_TILE}) { - $if ISA_CHECK: - ${ISA_CHECK}; - for (size_t elements = ${ELEMENTS_TILE+1}; elements < ${10 if ELEMENTS_TILE == 1 else ELEMENTS_TILE*2}; 
elements++) { - RAddExtExpMicrokernelTester() - .elements(elements) - .Test(${TEST_FUNCTION}); - } -} -""" - - -def generate_test_cases(ukernel, elements_tile, isa): - """Generates all tests cases for a RAddExtExp micro-kernel. - - Args: - ukernel: C name of the micro-kernel function. - elements_tile: Number of batch elements processed per one iteration of the - inner loop of the micro-kernel. - isa: instruction set required to run the micro-kernel. Generated unit test - will skip execution if the host processor doesn't support this ISA. - - Returns: - Code for the test case. - """ - _, test_name = ukernel.split("_", 1) - _, datatype, _ = ukernel.split("_", 2) - return xngen.preprocess(RADDEXTEXP_TEST_TEMPLATE, { - "TEST_FUNCTION": ukernel, - "TEST_NAME": test_name.upper().replace("UKERNEL_", ""), - "DATATYPE": datatype, - "ELEMENTS_TILE": elements_tile, - "ISA_CHECK": xnncommon.generate_isa_check_macro(isa), - }) - - -def main(args): - options = parser.parse_args(args) - - with codecs.open(options.spec, "r", encoding="utf-8") as spec_file: - spec_yaml = yaml.safe_load(spec_file) - if not isinstance(spec_yaml, list): - raise ValueError("expected a list of micro-kernels in the spec") - - tests = """\ -// Copyright 2019 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. -// -// Auto-generated file. Do not edit! -// Specification: {specification} -// Generator: {generator} - - -#include -#include "xnnpack/common.h" -#include "xnnpack/isa-checks.h" -#include "xnnpack/raddextexp.h" -#include "raddextexp-microkernel-tester.h" -""".format(specification=options.spec, generator=sys.argv[0]) - - for ukernel_spec in spec_yaml: - name = ukernel_spec["name"] - elements_tile, arch, isa = split_ukernel_name(name) - - test_case = generate_test_cases(name, elements_tile, isa) - tests += "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa) - - xnncommon.overwrite_if_changed(options.output, tests) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/tools/update-microkernels.py b/tools/update-microkernels.py index e6ecd944c78b..c41b7f6e2348 100755 --- a/tools/update-microkernels.py +++ b/tools/update-microkernels.py @@ -93,6 +93,7 @@ 'wasm32', 'wasmsimd32', 'wasmrelaxedsimd32', + 'amd64', }) _MICROKERNEL_NAME_REGEX = re.compile( @@ -273,7 +274,7 @@ def main(args): for component in basename.split('-'): if component in _ARCH_LIST: arch = component - elif component in _ISA_LIST: + if component in _ISA_LIST: isa = _ISA_MAP.get(component, component) key = isa if arch is None else f'{isa}_{arch}' c_microkernels_per_isa[key].append(filepath) diff --git a/tools/xnncommon.py b/tools/xnncommon.py index 38551a0545c8..3583ec40bdf0 100644 --- a/tools/xnncommon.py +++ b/tools/xnncommon.py @@ -29,6 +29,7 @@ def _remove_duplicate_newlines(text): "aarch64": "XNN_ARCH_ARM64", "x86-32": "XNN_ARCH_X86", "x86-64": "XNN_ARCH_X86_64", + "amd64": "XNN_ARCH_X86_64", "hexagon": "XNN_ARCH_HEXAGON", "riscv": "XNN_ARCH_RISCV", "wasm": "XNN_ARCH_WASM", @@ -298,4 +299,4 @@ def make_multiline_macro(x): lines = x.strip().split('\n') max_len = max([len(i) for i in lines]) lines = [i.ljust(max_len) + "\\" for i in lines] - return "\n".join(lines)[:-1].strip() + "\n" \ No newline at end of file + return "\n".join(lines)[:-1].strip() + "\n"
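Note on the recurring test change in this patch: every subgraph test now passes `xnn_test_runtime_flags()` from the new `runtime-flags.h` test helper to `xnn_create_runtime_v3`/`xnn_create_runtime_v4` instead of a hard-coded `/*flags=*/0`, so runtime flags can be varied across the whole test suite from one place. The helper's implementation is not part of this excerpt; the sketch below is one minimal way such a helper could look, assuming a `uint32_t` flags type and an environment-variable override, both of which are assumptions rather than the actual contents of test/runtime-flags.cc.

#include <cstdint>
#include <cstdlib>

// Hypothetical sketch only: returns the flags every subgraph test passes to
// xnn_create_runtime_v3/v4, so one knob (here an assumed environment variable)
// can change the behavior of all tests at once.
inline uint32_t xnn_test_runtime_flags() {
  if (const char* env = std::getenv("XNN_TEST_RUNTIME_FLAGS")) {
    return static_cast<uint32_t>(std::strtoul(env, /*str_end=*/nullptr, /*base=*/0));
  }
  return 0;  // Default: no extra flags, matching the old /*flags=*/0 behavior.
}

With such a helper, the pattern repeated throughout the tests reads `xnn_create_runtime_v3(subgraph, nullptr, nullptr, xnn_test_runtime_flags(), &runtime)`, which is exactly the substitution this patch applies.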